From a317cc57f5bb2fdd37612cef7915ce3eb73c26f5 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 9 Aug 2024 17:24:31 +0200 Subject: [PATCH 01/23] [WIP][FEAT] Add support for optimum-quanto This is unfinished, only pure implementations are provided. TODOs: - [ ] Documentation - [ ] Tests (should work on CPU!) - [ ] Whether Conv2d works is not verified yet - [ ] Optional: DoRA support - [ ] Optional: Mixed adapter batches support --- src/peft/import_utils.py | 5 + src/peft/tuners/lora/model.py | 2 + src/peft/tuners/lora/quanto.py | 376 +++++++++++++++++++++++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 src/peft/tuners/lora/quanto.py diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 58c65f9c80..c98e256f35 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -87,3 +87,8 @@ def is_eetq_available(): @lru_cache def is_hqq_available(): return importlib.util.find_spec("hqq") is not None + + +@lru_cache +def is_quanto_available(): + return importlib.util.find_spec("optimum.quanto") is not None diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 9d3f3bf62f..97b0db4a30 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -53,6 +53,7 @@ from .gptq import dispatch_gptq from .hqq import dispatch_hqq from .layer import Conv2d, LoraLayer, dispatch_default +from .quanto import dispatch_quanto from .tp_layer import dispatch_megatron @@ -330,6 +331,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs): dispatch_awq, dispatch_gptq, dispatch_hqq, + dispatch_quanto, dispatch_megatron, dispatch_default, ] diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py new file mode 100644 index 0000000000..9c0e06bcb9 --- /dev/null +++ b/src/peft/tuners/lora/quanto.py @@ -0,0 +1,376 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
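As an aside before the module body of the new file: a minimal sketch of the workflow this patch enables. It is not part of the diff; the model id and target modules are only illustrative, and the quanto calls mirror the ones used in the tests added later in this series.

```python
# Illustrative only, not part of quanto.py. Quantize a tiny model with optimum-quanto,
# then attach LoRA adapters; dispatch_quanto (registered above in lora/model.py) matches
# the resulting QLinear base layers and wraps them in QuantoLoraLinear.
from optimum.quanto import qint8, quantize
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"  # illustrative tiny model
model = AutoModelForCausalLM.from_pretrained(model_id)
quantize(model, weights=qint8)  # nn.Linear modules become quanto QLinear

config = LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```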
+from __future__ import annotations + +import math +import warnings +from typing import Any, Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from peft.import_utils import is_quanto_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + + +if is_quanto_available: + # ensure that there are no quanto imports unless optimum.quanto is installed + from optimum.quanto import QConv2d, QLinear +else: + QConv2d, QLinear = None, None + + +class QuantoLoraLinear(torch.nn.Module, LoraLayer): + """LoRA layer implementation for quanto QLinear""" + + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + super().__init__() + LoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + result = self.base_layer(x) + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is not None: + raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.") + + if self.disable_adapters: + return result + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + + return result + + def get_delta_weight(self, adapter): + return ( + transpose(self.lora_B[adapter].weight @ self.lora_A[adapter].weight, fan_in_fan_out=self.fan_in_fan_out) + * self.scaling[adapter] + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + from optimum.quanto import quantize_weight + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + base_layer = self.get_base_layer() + orig_weight = base_layer.weight + + for active_adapter in adapter_names: + delta_weight = self.get_delta_weight(active_adapter) + # note: no in-place for safe_merge=False + new_weight_data = orig_weight + delta_weight + if safe_merge: + if torch.isfinite(new_weight_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" + ) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.weight._data = quantized._data + base_layer.weight._scale = quantized._scale + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + from optimum.quanto import quantize_weight + + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + + base_layer = self.get_base_layer() + orig_weight = base_layer.weight + new_weight_data = orig_weight - self.get_delta_weight(active_adapter) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.weight._data = quantized._data + base_layer.weight._scale = quantized._scale + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +class QuantoLoraConv2d(torch.nn.Module, LoraLayer): + """LoRA layer implementation for quanto QConv2d""" + + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + # same as lora.layer.Conv2d + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + base_layer = self.get_base_layer() + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding + self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) + self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + # call this before dora_init + self._move_adapter_to_device_of_base_layer(adapter_name) + + if use_dora: + # TODO: Implement DoRA + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + result = self.base_layer(x) + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is not None: + raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.") + + if self.disable_adapters: + return result + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + 
result = self.base_layer(x, *args, **kwargs) + else: + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + + return result + + def get_delta_weight(self, adapter): + # same as lora.layer.Conv2d + device = self.lora_B[adapter].weight.device + dtype = self.lora_A[adapter].weight.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.get_base_layer().weight.size()[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( + 3 + ) * self.scaling[adapter] + else: + # conv2d 3x3 + output_tensor = ( + F.conv2d( + weight_A.permute(1, 0, 2, 3), + weight_B, + ).permute(1, 0, 2, 3) + * self.scaling[adapter] + ) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_A[adapter].weight.data = weight_A.to(dtype) + self.lora_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + # same as lora.quanto.QuantoLoraLinear + from optimum.quanto import quantize_weight + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + base_layer = self.get_base_layer() + orig_weight = base_layer.weight + + for active_adapter in adapter_names: + delta_weight = self.get_delta_weight(active_adapter) + # note: no in-place for safe_merge=False + new_weight_data = orig_weight + delta_weight + if safe_merge: + if torch.isfinite(new_weight_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.weight._data = quantized._data + base_layer.weight._scale = quantized._scale + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + # same as lora.quanto.QuantoLoraLinear + from optimum.quanto import quantize_weight + + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + + base_layer = self.get_base_layer() + orig_weight = base_layer.weight + new_weight_data = orig_weight - self.get_delta_weight(active_adapter) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.weight._data = quantized._data + base_layer.weight._scale = quantized._scale + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +def dispatch_quanto( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_quanto_available() and isinstance(target_base_layer, QLinear): + new_module = QuantoLoraLinear(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + elif is_quanto_available() and isinstance(target_base_layer, QConv2d): + new_module = QuantoLoraConv2d(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + + return new_module From caad385a04e8d91cc7ef6373b354fc31dbbaf7bb Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 2 Sep 2024 18:02:30 +0200 Subject: [PATCH 02/23] Add some unit tests --- src/peft/tuners/lora/quanto.py | 4 +- tests/test_quanto.py | 488 +++++++++++++++++++++++++++++++++ tests/testing_common.py | 4 + 3 files changed, 494 insertions(+), 2 deletions(-) create mode 100644 tests/test_quanto.py diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index 9c0e06bcb9..2be06781b4 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -123,7 +123,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + quantized = quantize_weight(new_weight_data, qtype=base_layer.qweight.qtype, axis=base_layer.qweight.axis) base_layer.weight._data = quantized._data base_layer.weight._scale = quantized._scale self.merged_adapters.append(active_adapter) @@ -143,7 +143,7 @@ def unmerge(self) -> None: base_layer = self.get_base_layer() orig_weight = base_layer.weight new_weight_data = orig_weight - self.get_delta_weight(active_adapter) - quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + quantized = quantize_weight(new_weight_data, qtype=base_layer.qweight.qtype, axis=base_layer.qweight.axis) base_layer.weight._data = quantized._data base_layer.weight._scale = quantized._scale diff --git a/tests/test_quanto.py b/tests/test_quanto.py new file mode 100644 index 0000000000..9571eb6b56 --- /dev/null +++ b/tests/test_quanto.py @@ -0,0 +1,488 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO describe this test module +import unittest +from unittest.mock import Mock, call, patch + +import pytest +import torch +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer + +from peft import ( + AdaLoraConfig, + BOFTConfig, + HRAConfig, + LoraConfig, + PrefixTuningConfig, + PromptTuningConfig, + PromptTuningInit, + get_peft_model, +) + +from .testing_common import PeftCommonTester, PeftTestConfigManager + + +PEFT_DECODER_MODELS_TO_TEST = [ + "hf-internal-testing/tiny-random-OPTForCausalLM", + # "hf-internal-testing/tiny-random-GPT2LMHeadModel", + # "trl-internal-testing/tiny-random-LlamaForCausalLM", + # "peft-internal-testing/tiny-dummy-qwen2", +] + +FULL_GRID = { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "task_type": "CAUSAL_LM", +} + + +def skip_adalora_and_gpt2(test_list): + return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))] + + +def skip_boft_or_hra_and_gpt2(test_list): + return [ + test + for test in test_list + if not (("GPT2LMHeadModel" in test[1]) and ((test[2] == BOFTConfig) or (test[2] == HRAConfig))) + ] + + +def skip_adalora_or_boft_or_hra_and_gpt2(test_list): + return [ + test + for test in test_list + if not ( + ("GPT2LMHeadModel" in test[1]) + and ((test[2] == AdaLoraConfig) or (test[2] == BOFTConfig) or (test[2] == HRAConfig)) + ) + ] + + +def make_automodel_proxy(weights: str): + """Instantiate a quanto-quantized transformers model. + + As quanto is not yet integrated into transformers itself, this is done manually for now but should be replaced once + transformers supports it. 
+ + """ + # TODO: switch to QuantoConfig once https://github.com/huggingface/transformers/pull/31732 is merged + + from optimum.quanto import qfloat8, qint2, qint4, qint8, quantize + from transformers.utils.quantization_config import QuantizationMethod + + class QuantoModelProxy: + @classmethod + def from_pretrained(self, *args, **kwargs): + model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) + if weights == "int2": + quantize(model, weights=qint2) + elif weights == "int4": + quantize(model, weights=qint4) + elif weights == "int8": + quantize(model, weights=qint8) + elif weights == "float8": + quantize(model, weights=qfloat8) + else: + raise ValueError(f"Invalid quantization dtype for quanto: {weights}") + + model.quantization_method = QuantizationMethod.QUANTO + return model + + return QuantoModelProxy + + +class BasePeftQuantoModelTester: + r"""TODO""" + + def prepare_inputs_for_testing(self): + input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) + attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + + input_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + return input_dict + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): + self._test_model_attr(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): + self._test_adapter_name(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): + self._test_prepare_for_training(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_prompt_tuning_text_prepare_for_training(self, test_name, model_id, config_cls, config_kwargs): + # Test that prompt tuning works with text init + if config_cls != PromptTuningConfig: + return pytest.skip(f"This test does not apply to {config_cls}") + + config_kwargs = config_kwargs.copy() + config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT + config_kwargs["prompt_tuning_init_text"] = "This is a test prompt." 
+ config_kwargs["tokenizer_name_or_path"] = model_id + self._test_prepare_for_training(model_id, config_cls, config_kwargs) + + def test_prompt_tuning_text_tokenizer_kwargs(self): + # Allow users to pass additional arguments to Tokenizer.from_pretrained + # Fix for #1032 + mock = Mock() + orig_from_pretrained = AutoTokenizer.from_pretrained + + def mock_autotokenizer_from_pretrained(*args, **kwargs): + mock(*args, **kwargs) + return orig_from_pretrained(config.tokenizer_name_or_path) + + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + config = PromptTuningConfig( + base_model_name_or_path=model_id, + tokenizer_name_or_path=model_id, + num_virtual_tokens=10, + prompt_tuning_init=PromptTuningInit.TEXT, + task_type="CAUSAL_LM", + prompt_tuning_init_text="This is a test prompt.", + tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, + ) + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + with patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained): + model = get_peft_model(model, config) + + expected_call = call(model_id, trust_remote_code=True, foo="bar") + assert mock.call_args == expected_call + + def test_prompt_tuning_config_invalid_args(self): + # Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no + # function in that case + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + with pytest.raises(ValueError, match="tokenizer_kwargs only valid when using prompt_tuning_init='TEXT'."): + PromptTuningConfig( + base_model_name_or_path=model_id, + tokenizer_name_or_path=model_id, + num_virtual_tokens=10, + task_type="CAUSAL_LM", + prompt_tuning_init_text="This is a test prompt.", + prompt_tuning_init=PromptTuningInit.RANDOM, # <= should not be used together with tokenizer_kwargs + tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, + ) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): + self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) + + @parameterized.expand( + 
PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): + self._test_merge_layers(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=skip_boft_or_hra_and_gpt2, + ) + ) + def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): + self._test_merge_layers_multi(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): + self._test_merge_layers_nan(model_id, config_cls, config_kwargs) + + # TODO: enable if/when mixed batch inference is supported + # @parameterized.expand( + # PeftTestConfigManager.get_grid_parameters( + # { + # "model_ids": PEFT_DECODER_MODELS_TO_TEST, + # "lora_kwargs": {"init_lora_weights": [False]}, + # "task_type": "CAUSAL_LM", + # }, + # ) + # ) + # def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): + # self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_generate(self, test_name, model_id, config_cls, config_kwargs): + self._test_generate(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs): + # positional args are supported for PeftModelForCausalLM + self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): + self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): + self._test_generate_half_prec(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): + self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) + + @parameterized.expand( + 
PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): + self._test_training(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): + self._test_training_layer_indexing(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): + self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): + self._test_inference_safetensors(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): + self._test_peft_model_device_map(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_delete_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): + self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=skip_adalora_or_boft_or_hra_and_gpt2, + ) + ) + def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_unload_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + def test_weighted_combination_of_adapters(self, test_name, model_id, config_cls, config_kwargs): + self._test_weighted_combination_of_adapters(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def 
test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, config_kwargs): + self._test_training_prompt_learning_tasks(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=skip_boft_or_hra_and_gpt2, + ) + ) + def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_disable_adapter(model_id, config_cls, config_kwargs) + + def test_generate_adalora_no_dropout(self): + # test for issue #730 + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + config_kwargs = { + "target_modules": None, + "task_type": "CAUSAL_LM", + "lora_dropout": 0.0, + } + self._test_generate(model_id, AdaLoraConfig, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) + def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): + self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) + + def test_lora_layer_replication(self): + model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" + config_kwargs = { + "target_modules": ["down_proj", "up_proj"], + "task_type": "CAUSAL_LM", + "lora_dropout": 0.0, + "layer_replication": [[0, 1], [0, 2], [1, 2]], + } + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = LoraConfig( + base_model_name_or_path=model_id, + **config_kwargs, + ) + assert len(model.model.layers), "Expected 2 layers in original model." == 2 + model = get_peft_model(model, config) + layers = model.base_model.model.model.layers + assert len(layers) == 4, "Expected 4 layers in adapted model." + assert ( + layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + == layers[1].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + and layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + == layers[3].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + ), "Expected layers 0-1 and 2-3 to share weights" + assert ( + layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + != layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + ), "Expected layers 0 and 2 to have different weights" + assert ( + layers[0].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + != layers[1].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + and layers[2].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + != layers[3].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + ), "Expected all LoRA adapters to have distinct weights" + assert ( + len([n for n, _ in model.named_parameters() if ".lora_A." in n]) == 8 + ), "Expected 8 LoRA adapters since we are adding one each for up and down." 
+ self._test_prepare_for_training(model_id, LoraConfig, config_kwargs) + self._test_generate(model_id, LoraConfig, config_kwargs) + + def test_prompt_learning_with_grouped_query_attention(self): + # See 1901, fixes a bug with handling GQA + model_id = "peft-internal-testing/tiny-dummy-qwen2" + base_model = AutoModelForCausalLM.from_pretrained(model_id) + peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM") + model = get_peft_model(base_model, peft_config) + x = torch.tensor([[1, 2, 3]]) + # does not raise + model(x) + + +class PeftQuanto4bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): + r"""TODO""" + + transformers_class = make_automodel_proxy(weights="int4") + + +class PeftQuanto8bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): + r"""TODO""" + + transformers_class = make_automodel_proxy(weights="int8") + +# TODO: qint2, qfloat8 diff --git a/tests/testing_common.py b/tests/testing_common.py index da58337bc2..be88effd83 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -610,6 +610,10 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): if (config.peft_type == "IA3") and (model_id == "Conv2d"): # for some reason, the IA³ Conv2d introduces a larger error atol, rtol = 0.3, 0.01 + if quant_method := getattr(model, "quantization_method", None): + if quant_method.value == "quanto": + atol, rtol = 5e-3, 5e-3 + assert torch.allclose(logits, logits_merged, atol=atol, rtol=rtol) assert torch.allclose(logits, logits_unmerged, atol=atol, rtol=rtol) assert torch.allclose(logits, logits_merged_unloaded, atol=atol, rtol=rtol) From 44d77b41eeb134036e8aafbfa86455836dff1d6b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 3 Sep 2024 16:44:29 +0200 Subject: [PATCH 03/23] More progress on tests, but still many fail --- src/peft/tuners/lora/quanto.py | 9 +- tests/test_quanto.py | 229 +++++++++++++++++++++++++++++++-- 2 files changed, 223 insertions(+), 15 deletions(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index 2be06781b4..c95bb10ea0 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -118,11 +118,10 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N delta_weight = self.get_delta_weight(active_adapter) # note: no in-place for safe_merge=False new_weight_data = orig_weight + delta_weight - if safe_merge: - if torch.isfinite(new_weight_data).all(): - raise ValueError( - f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" - ) + if safe_merge and not torch.isfinite(new_weight_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) quantized = quantize_weight(new_weight_data, qtype=base_layer.qweight.qtype, axis=base_layer.qweight.axis) base_layer.weight._data = quantized._data base_layer.weight._scale = quantized._scale diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 9571eb6b56..49a3cb1c5b 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -13,6 +13,9 @@ # limitations under the License. # TODO describe this test module +import copy +import shutil +import tempfile import unittest from unittest.mock import Mock, call, patch @@ -78,9 +81,10 @@ def make_automodel_proxy(weights: str): transformers supports it. 
""" - # TODO: switch to QuantoConfig once https://github.com/huggingface/transformers/pull/31732 is merged - - from optimum.quanto import qfloat8, qint2, qint4, qint8, quantize + # TODO: Can't use `from transformers import QuantoConfig` because it checks for the quanto package, but quanto is + # now part of optimum, resulting in the check to fail. + # Switch to QuantoConfig once https://github.com/huggingface/transformers/pull/31732 is merged + from optimum.quanto import QuantizedModelForCausalLM, qfloat8, qint2, qint4, qint8 from transformers.utils.quantization_config import QuantizationMethod class QuantoModelProxy: @@ -88,13 +92,13 @@ class QuantoModelProxy: def from_pretrained(self, *args, **kwargs): model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) if weights == "int2": - quantize(model, weights=qint2) + QuantizedModelForCausalLM.quantize(model, weights=qint2) elif weights == "int4": - quantize(model, weights=qint4) + QuantizedModelForCausalLM.quantize(model, weights=qint4) elif weights == "int8": - quantize(model, weights=qint8) + QuantizedModelForCausalLM.quantize(model, weights=qint8) elif weights == "float8": - quantize(model, weights=qfloat8) + QuantizedModelForCausalLM.quantize(model, weights=qfloat8) else: raise ValueError(f"Invalid quantization dtype for quanto: {weights}") @@ -107,6 +111,9 @@ def from_pretrained(self, *args, **kwargs): class BasePeftQuantoModelTester: r"""TODO""" + # expected minimum correlation between logits before and after merging + min_correlation = 0.95 + def prepare_inputs_for_testing(self): input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) @@ -220,6 +227,9 @@ def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, con def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) + def get_correlation_matrix(self, *tensors): + return torch.corrcoef(torch.stack([t.flatten() for t in tensors])) + @parameterized.expand( PeftTestConfigManager.get_grid_parameters( { @@ -236,7 +246,33 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c ) ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - self._test_merge_layers(model_id, config_cls, config_kwargs) + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking + # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. 
+ config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + model.eval() + logits = model(**dummy_input)[0] + + model.merge_adapter() + logits_merged = model(**dummy_input)[0] + model.unmerge_adapter() + logits_unmerged = model(**dummy_input)[0] + + model = model.merge_and_unload() + logits_merged_unloaded = model(**dummy_input)[0] + + cc_matrix = self.get_correlation_matrix(logits, logits_merged, logits_unmerged, logits_merged_unloaded) + assert cc_matrix.min() > self.min_correlation @parameterized.expand( PeftTestConfigManager.get_grid_parameters( @@ -253,8 +289,67 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): filter_params_func=skip_boft_or_hra_and_gpt2, ) ) + # TODO: enable if/when deepcopy-ing is supported + @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): - self._test_merge_layers_multi(model_id, config_cls, config_kwargs) + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking + # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. + + # NOTE: don't use with `torch.inference_mode()`, see: https://github.com/huggingface/optimum-quanto/issues/304 + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + model.eval() + + logits_adapter_1 = model(**dummy_input)[0] + + model.add_adapter("adapter-2", config) + model.set_adapter("adapter-2") + model.eval() + + logits_adapter_2 = model(**dummy_input)[0] + + assert not torch.allclose(logits_adapter_1, logits_adapter_2, atol=1e-3, rtol=1e-3) + + model.set_adapter("default") + + logits_adapter_1_after_set = model(**dummy_input)[0] + + cc_matrix = self.get_correlation_matrix(logits_adapter_1, logits_adapter_1_after_set) + assert cc_matrix.min() > self.min_correlation + + model_copy = copy.deepcopy(model) + model_copy_2 = copy.deepcopy(model) + model_merged_all = model.merge_and_unload(adapter_names=["adapter-2", "default"]) + + logits_merged_all = model_merged_all(**dummy_input)[0] + + assert not torch.allclose(logits_merged_all, logits_adapter_2, atol=1e-3, rtol=1e-3) + assert not torch.allclose(logits_merged_all, logits_adapter_1, atol=1e-3, rtol=1e-3) + + model_merged_adapter_2 = model_copy.merge_and_unload(adapter_names=["adapter-2"]) + + logits_merged_adapter_2 = model_merged_adapter_2(**dummy_input)[0] + + cc_matrix = self.get_correlation_matrix(logits_adapter_2, logits_merged_adapter_2) + assert cc_matrix.min() > self.min_correlation + + model_merged_adapter_default = model_copy_2.merge_and_unload(adapter_names=["default"]) + + logits_merged_adapter_default = model_merged_adapter_default(**dummy_input)[0] + + cc_matrix = self.get_correlation_matrix(logits_adapter_1, logits_merged_adapter_default) + assert cc_matrix.min() > self.min_correlation @parameterized.expand( 
PeftTestConfigManager.get_grid_parameters( @@ -268,7 +363,117 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs ) ) def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): - self._test_merge_layers_nan(model_id, config_cls, config_kwargs) + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking + # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + + model.eval() + + # This should work + logits_unmerged = model(**dummy_input)[0] + + model = model.merge_and_unload() + logits_merged = model(**dummy_input)[0] + + cc_matrix = self.get_correlation_matrix(logits_unmerged, logits_merged) + assert cc_matrix.min() > self.min_correlation + + model = self.transformers_class.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + prefixes = ["lora_A", "boft_R", "fourierft_spectrum", "hra_u", "hada_w1", "lokr_w1", "ia3_l", "oft_r"] + prefixes += ["vera_lambda_b"] + + for name, module in model.named_parameters(): + if any(prefix in name for prefix in prefixes): + module.data[0] = torch.nan + + with pytest.raises( + ValueError, match="NaNs detected in the merged weights. The adapter default seems to be broken" + ): + model = model.merge_and_unload(safe_merge=True) + + for name, module in model.named_parameters(): + if any(prefix in name for prefix in prefixes): + module.data[0] = torch.inf + + with pytest.raises( + ValueError, match="NaNs detected in the merged weights. The adapter default seems to be broken" + ): + model = model.merge_and_unload(safe_merge=True) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + @pytest.mark.xfail(strict=True) + def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, config_kwargs): + # Saving and loading a quanto model that has been merged and unloaded does not work correctly. Here is the + # reason: Quanto requires its own save_pretrained method, which, among others, saves the quantization map. + # Without it, the model cannot be correctly loaded. To make use of this, we should thus use a quanto + # QuantizedModel instance instead of a PretrainedModel instance. However, the QuantizedModel instance cannot be + # used for anything else, e.g. it has no __call__ method. Therefore, we cannot use that in PEFT. Therefore, + # users need to pass the PretrainedModel instance to get_peft_model, thus we don't have the modified + # save_pretrained, thus loading the merged and unloaded model does not work. 
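As an aside (not part of the diff): the round trip that does preserve the quantization state is the one that goes through quanto's own wrapper, which PEFT currently cannot hook into for the reasons given above. A rough sketch, with an illustrative checkpoint directory:

```python
# Sketch only: quantize via the QuantizedModelForCausalLM wrapper so that its
# save_pretrained also writes the quantization map, then reload with the same class.
from optimum.quanto import QuantizedModelForCausalLM, qint8
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
qmodel = QuantizedModelForCausalLM.quantize(base, weights=qint8)
qmodel.save_pretrained("quanto-checkpoint")  # illustrative path
reloaded = QuantizedModelForCausalLM.from_pretrained("quanto-checkpoint")
```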
+ from optimum.quanto import QuantizedModelForCausalLM + + model = self.transformers_class.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + model = model.merge_and_unload() + model.eval() + + dummy_input = self.prepare_inputs_for_testing() + logits = model(**dummy_input)[0] + + # model is a transformers model + tmp_dirname = tempfile.mkdtemp() + # note: not using the context manager here because it fails on Windows CI for some reason + try: + model.save_pretrained(tmp_dirname) + # Carefuly: must use QuantizedModelForCausalLM.from_pretrained not AutoModelForCausalLM.from_pretrained + model_from_pretrained = QuantizedModelForCausalLM.from_pretrained(tmp_dirname).to(self.torch_device) + finally: + try: + shutil.rmtree(tmp_dirname) + except PermissionError: + # windows error + pass + + logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0] + cc_matrix = self.get_correlation_matrix(logits, logits_merged_from_pretrained) + assert cc_matrix.min() > self.min_correlation # TODO: enable if/when mixed batch inference is supported # @parameterized.expand( @@ -305,6 +510,7 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs self._test_generate_half_prec(model_id, config_cls, config_kwargs) @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") def test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) @@ -424,6 +630,8 @@ def test_generate_adalora_no_dropout(self): def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) + # TODO: enable if/when deepcopy-ing is supported + @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") def test_lora_layer_replication(self): model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" config_kwargs = { @@ -485,4 +693,5 @@ class PeftQuanto8bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQua transformers_class = make_automodel_proxy(weights="int8") + # TODO: qint2, qfloat8 From d334eb3759246bda799deb571c3737f63cab965b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 4 Sep 2024 16:01:15 +0200 Subject: [PATCH 04/23] Skip merge tests that are not LoRA Other methods would have to explicitly support quanto for these tests to pass. 
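A minimal sketch of the filtering hook this patch introduces (not part of the diff): only configs listed in PEFT_METHODS_SUPPORTING_MERGING survive the parameterized grid. The tuple layout, with the config class at index 2, follows the existing helpers; the grid entries below are made up for illustration.

```python
# Sketch of the filter_params_func mechanism used by the merge tests: keep LoRA only.
from peft import IA3Config, LoraConfig

PEFT_METHODS_SUPPORTING_MERGING = [LoraConfig]

def filter_supported_methods_supporting_merging(test_list):
    return [test for test in test_list if any(test[2] is cls for cls in PEFT_METHODS_SUPPORTING_MERGING)]

grid = [
    ("test_0_lora", "tiny-opt", LoraConfig, {"init_lora_weights": False}),
    ("test_1_ia3", "tiny-opt", IA3Config, {"init_ia3_weights": False}),
]
assert [name for name, *_ in filter_supported_methods_supporting_merging(grid)] == ["test_0_lora"]
```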
--- tests/test_quanto.py | 124 ++++++++++++------------------------------- 1 file changed, 33 insertions(+), 91 deletions(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 49a3cb1c5b..81387d530b 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -25,9 +25,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from peft import ( - AdaLoraConfig, - BOFTConfig, - HRAConfig, LoraConfig, PrefixTuningConfig, PromptTuningConfig, @@ -38,6 +35,14 @@ from .testing_common import PeftCommonTester, PeftTestConfigManager +# only the PEFT methods that are explicitly supported will be tested +PEFT_METHODS_SUPPORTING_MERGING = [LoraConfig] + + +def filter_supported_methods_supporting_merging(test_list): + return [test for test in test_list if any(test[2] is cls for cls in PEFT_METHODS_SUPPORTING_MERGING)] + + PEFT_DECODER_MODELS_TO_TEST = [ "hf-internal-testing/tiny-random-OPTForCausalLM", # "hf-internal-testing/tiny-random-GPT2LMHeadModel", @@ -51,29 +56,6 @@ } -def skip_adalora_and_gpt2(test_list): - return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))] - - -def skip_boft_or_hra_and_gpt2(test_list): - return [ - test - for test in test_list - if not (("GPT2LMHeadModel" in test[1]) and ((test[2] == BOFTConfig) or (test[2] == HRAConfig))) - ] - - -def skip_adalora_or_boft_or_hra_and_gpt2(test_list): - return [ - test - for test in test_list - if not ( - ("GPT2LMHeadModel" in test[1]) - and ((test[2] == AdaLoraConfig) or (test[2] == BOFTConfig) or (test[2] == HRAConfig)) - ) - ] - - def make_automodel_proxy(weights: str): """Instantiate a quanto-quantized transformers model. @@ -125,21 +107,15 @@ def prepare_inputs_for_testing(self): return input_dict - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_model_attr(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): self._test_adapter_name(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_prepare_for_training(model_id, config_cls, config_kwargs) @@ -197,33 +173,23 @@ def test_prompt_tuning_config_invalid_args(self): tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, ) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def 
test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) @@ -243,6 +209,7 @@ def get_correlation_matrix(self, *tensors): "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=filter_supported_methods_supporting_merging, ) ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): @@ -286,7 +253,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_boft_or_hra_and_gpt2, + filter_params_func=filter_supported_methods_supporting_merging, ) ) # TODO: enable if/when deepcopy-ing is supported @@ -360,6 +327,7 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs "boft_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=filter_supported_methods_supporting_merging, ) ) def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): @@ -431,6 +399,7 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=filter_supported_methods_supporting_merging, ) ) @pytest.mark.xfail(strict=True) @@ -488,20 +457,19 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co # def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): # self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_generate(self, test_name, model_id, config_cls, config_kwargs): self._test_generate(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs): # positional args are supported for PeftModelForCausalLM 
self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=filter_supported_methods_supporting_merging) + ) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) @@ -514,9 +482,7 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs def test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): self._test_training(model_id, config_cls, config_kwargs) @@ -524,15 +490,11 @@ def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs) def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): self._test_training_layer_indexing(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): self._test_inference_safetensors(model_id, config_cls, config_kwargs) @@ -540,21 +502,15 @@ def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwa def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): self._test_peft_model_device_map(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) @@ -571,7 +527,6 @@ def 
test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_adalora_or_boft_or_hra_and_gpt2, ) ) def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): @@ -608,25 +563,12 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_boft_or_hra_and_gpt2, ) ) def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_disable_adapter(model_id, config_cls, config_kwargs) - def test_generate_adalora_no_dropout(self): - # test for issue #730 - model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" - config_kwargs = { - "target_modules": None, - "task_type": "CAUSAL_LM", - "lora_dropout": 0.0, - } - self._test_generate(model_id, AdaLoraConfig, config_kwargs) - - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) - ) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) From 6d8b07196e37f043b117228ca139d0bdd400effe Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 4 Sep 2024 16:28:19 +0200 Subject: [PATCH 05/23] Add tests for int2 float8 produces nans, so not used right now. --- tests/test_quanto.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 81387d530b..e11b00e25f 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -94,7 +94,8 @@ class BasePeftQuantoModelTester: r"""TODO""" # expected minimum correlation between logits before and after merging - min_correlation = 0.95 + # subclasses should override this with a float between 0 and 1 + min_correlation = "missing" def prepare_inputs_for_testing(self): input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) @@ -215,6 +216,8 @@ def get_correlation_matrix(self, *tensors): def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. + torch.manual_seed(0) + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -261,8 +264,9 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. - # NOTE: don't use with `torch.inference_mode()`, see: https://github.com/huggingface/optimum-quanto/issues/304 + torch.manual_seed(0) + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -333,6 +337,8 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): # Not using PeftCommonTester for merging tests as merging is too imprecise. 
So instead of checking # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. + torch.manual_seed(0) + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -413,6 +419,8 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co # save_pretrained, thus loading the merged and unloaded model does not work. from optimum.quanto import QuantizedModelForCausalLM + torch.manual_seed(0) + model = self.transformers_class.from_pretrained(model_id) config = config_cls( base_model_name_or_path=model_id, @@ -624,16 +632,16 @@ def test_prompt_learning_with_grouped_query_attention(self): model(x) -class PeftQuanto4bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): - r"""TODO""" +class PeftQuanto2bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): + transformers_class = make_automodel_proxy(weights="int2") + min_correlation = 0.9 + +class PeftQuanto4bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): transformers_class = make_automodel_proxy(weights="int4") + min_correlation = 0.95 class PeftQuanto8bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): - r"""TODO""" - transformers_class = make_automodel_proxy(weights="int8") - - -# TODO: qint2, qfloat8 + min_correlation = 0.95 From c4cc6daba1fdf5942c6ad035b37954dc0034c2f6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 4 Sep 2024 16:50:32 +0200 Subject: [PATCH 06/23] Add test for conv2d --- tests/test_quanto.py | 67 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index e11b00e25f..f1916c3380 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -22,6 +22,7 @@ import pytest import torch from parameterized import parameterized +from torch import nn from transformers import AutoModelForCausalLM, AutoTokenizer from peft import ( @@ -631,6 +632,72 @@ def test_prompt_learning_with_grouped_query_attention(self): # does not raise model(x) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_quanto_merge_conv2d(self, test_name, model_id, config_cls, config_kwargs): + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + config.target_modules = {"conv2d"} + config.task_type = None + + class ModelConv2D(nn.Module): + def __init__(self): + super().__init__() + self.conv2d = nn.Conv2d(5, 10, 3) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin0 = nn.Linear(10, 2) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float().reshape(-1, 5, 3, 3) + X = self.conv2d(X) + X = self.relu(X) + X = self.flat(X) + X = self.lin0(X) + X = self.sm(X) + return X + + model = ModelConv2D() + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = 
torch.arange(90).view(9, 10).to(self.torch_device) + model.eval() + logits = model(dummy_input)[0] + + model.merge_adapter() + logits_merged = model(dummy_input)[0] + model.unmerge_adapter() + logits_unmerged = model(dummy_input)[0] + + model = model.merge_and_unload() + logits_merged_unloaded = model(dummy_input)[0] + + cc_matrix = self.get_correlation_matrix(logits, logits_merged, logits_unmerged, logits_merged_unloaded) + assert cc_matrix.min() > self.min_correlation + class PeftQuanto2bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): transformers_class = make_automodel_proxy(weights="int2") From c50c7c60eab76b40e2a1a0f7f0c039707f2593a8 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 4 Sep 2024 16:51:32 +0200 Subject: [PATCH 07/23] Add some quanto docs --- docs/source/developer_guides/quantization.md | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md index 114021cafc..381d6e27fe 100644 --- a/docs/source/developer_guides/quantization.md +++ b/docs/source/developer_guides/quantization.md @@ -187,6 +187,29 @@ peft_config = LoraConfig(...) quantized_model = get_peft_model(quantized_model, peft_config) ``` +## Optimum-quanto + +PEFT supports models quantized with [optimum-quanto](https://github.com/huggingface/optimum-quanto). This has been tested with 2bit, 4bit, and 8bit int quantization. Optimum-quanto also works on CPU and MPS. + +```python +from transformers import AutoModelForCausalLM +from optimum.quanto import QuantizedModelForCausalLM, qint2, qint4, qint8 + +model_id = ... +base_model = AutoModelForCausalLM.from_pretrained(model_id) +QuantizedModelForCausalLM.quantize(model, weights=qint4) # or qint2 or qint8 +peft_config = LoraConfig(...) +model = get_peft_model(base_model, peft_config) +``` + + + +### Caveats: + +- Use a version > 2.4.0, otherwise saving and loading won't work properly. +- Float8 is discouraged as it can easily produce NaNs. +- There is explicit support for optimum-quanto when used with LoRA. However, when optimum-quanto quantizes a layer, it remains a subclass of the corresponding torch class (e.g., quanto's `QLinear` is a subclass of `nn.Linear`). For this reason, methods will generally also work with optimum-quanto, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA**. If you use a method other than LoRA, merging may not raise an error but the results will be incorrect. 
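+
+As a minimal sketch (reusing the `model` created with `get_peft_model` above), merging then works like for any other LoRA model:
+
+```python
+# temporarily merge the LoRA weights into the quantized base weights
+model.merge_adapter()
+# ... run inference with the merged weights ...
+model.unmerge_adapter()  # restore the original quantized weights
+
+# or merge permanently and drop the PEFT wrapper
+merged_model = model.merge_and_unload()
+```
+
+Because the merged weights are re-quantized, merging into quanto-quantized layers is approximate, so expect small deviations in the outputs compared to running with the unmerged adapter.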
+ ## Next steps If you're interested in learning more about quantization, the following may be helpful: From 8cece29a6de5794537e910d85735e8515d955cc8 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 5 Sep 2024 12:37:19 +0200 Subject: [PATCH 08/23] More fixes to quanto tests, should now pass --- src/peft/tuners/lora/quanto.py | 2 +- tests/test_quanto.py | 127 +++++++++++++++++++-------------- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index c95bb10ea0..1dda88dd18 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -173,7 +173,7 @@ def __init__( LoraLayer.__init__(self, base_layer) self._active_adapter = adapter_name - self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora) def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): # same as lora.layer.Conv2d diff --git a/tests/test_quanto.py b/tests/test_quanto.py index f1916c3380..a858588c23 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# TODO describe this test module import copy import shutil import tempfile @@ -36,7 +34,7 @@ from .testing_common import PeftCommonTester, PeftTestConfigManager -# only the PEFT methods that are explicitly supported will be tested +# only the PEFT methods that are explicitly supported will be tested for merging PEFT_METHODS_SUPPORTING_MERGING = [LoraConfig] @@ -67,7 +65,7 @@ def make_automodel_proxy(weights: str): # TODO: Can't use `from transformers import QuantoConfig` because it checks for the quanto package, but quanto is # now part of optimum, resulting in the check to fail. # Switch to QuantoConfig once https://github.com/huggingface/transformers/pull/31732 is merged - from optimum.quanto import QuantizedModelForCausalLM, qfloat8, qint2, qint4, qint8 + from optimum.quanto import QuantizedModelForCausalLM, qint2, qint4, qint8 from transformers.utils.quantization_config import QuantizationMethod class QuantoModelProxy: @@ -80,9 +78,8 @@ def from_pretrained(self, *args, **kwargs): QuantizedModelForCausalLM.quantize(model, weights=qint4) elif weights == "int8": QuantizedModelForCausalLM.quantize(model, weights=qint8) - elif weights == "float8": - QuantizedModelForCausalLM.quantize(model, weights=qfloat8) else: + # float8 was tried but was producing NaNs raise ValueError(f"Invalid quantization dtype for quanto: {weights}") model.quantization_method = QuantizationMethod.QUANTO @@ -92,11 +89,42 @@ def from_pretrained(self, *args, **kwargs): class BasePeftQuantoModelTester: - r"""TODO""" + r"""Base class implementing tests for quanto-quantized models. + + This class is based on PeftDecoderModelTester with some quanto-specific edits, especially for the merging tests, + which are less precise due to the quantization. + Subclasses should implement the attributes below. 
+ """ + + # The weights argument for quanto, should be "int2", "int4", or "int8" + weights = "MISSING" + # transformers class should be make_automodel_proxy(weights=weights) + transformers_class = "MISSING" # expected minimum correlation between logits before and after merging # subclasses should override this with a float between 0 and 1 - min_correlation = "missing" + min_correlation = "MISSING" + # the allowed tolerance for comparing the output tensors + tol = "MISSING" + + def _get_correlation_matrix(self, *tensors): + return torch.corrcoef(torch.stack([t.flatten() for t in tensors])) + + def check_tensors_approximately_equal(self, *tensors): + # Strict equality checks will fail due to the quantization, so we check: + # 1. The correlation between the tensors is high + # 2. Tensor equality after removing 1% of highest and lowest outliers + cc_matrix = self._get_correlation_matrix(*tensors) + assert cc_matrix.min() > self.min_correlation + + for tensor0, tensor1 in zip(tensors, tensors[1:]): + tensor0, tensor1 = tensor0.flatten(), tensor1.flatten() + diff = tensor0 - tensor1 + indices = torch.argsort(diff) + # remove 1% outliers on both ends + indices = indices[len(indices) // 100 : -len(indices) // 100] + tensor0, tensor1 = tensor0[indices], tensor1[indices] + assert torch.allclose(tensor0, tensor1, atol=self.tol, rtol=self.tol) def prepare_inputs_for_testing(self): input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) @@ -195,9 +223,6 @@ def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, con def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) - def get_correlation_matrix(self, *tensors): - return torch.corrcoef(torch.stack([t.flatten() for t in tensors])) - @parameterized.expand( PeftTestConfigManager.get_grid_parameters( { @@ -215,8 +240,8 @@ def get_correlation_matrix(self, *tensors): ) ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking - # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a + # custom check that relies on correlation and outlier removal torch.manual_seed(0) config = config_cls( @@ -242,8 +267,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): model = model.merge_and_unload() logits_merged_unloaded = model(**dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits, logits_merged, logits_unmerged, logits_merged_unloaded) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) @parameterized.expand( PeftTestConfigManager.get_grid_parameters( @@ -263,8 +287,8 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): # TODO: enable if/when deepcopy-ing is supported @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): - # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking - # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. 
+ # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a + # custom check that relies on correlation and outlier removal # NOTE: don't use with `torch.inference_mode()`, see: https://github.com/huggingface/optimum-quanto/issues/304 torch.manual_seed(0) @@ -297,8 +321,7 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs logits_adapter_1_after_set = model(**dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits_adapter_1, logits_adapter_1_after_set) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits_adapter_1, logits_adapter_1_after_set) model_copy = copy.deepcopy(model) model_copy_2 = copy.deepcopy(model) @@ -313,15 +336,12 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs logits_merged_adapter_2 = model_merged_adapter_2(**dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits_adapter_2, logits_merged_adapter_2) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits_adapter_2, logits_merged_adapter_2) model_merged_adapter_default = model_copy_2.merge_and_unload(adapter_names=["default"]) - logits_merged_adapter_default = model_merged_adapter_default(**dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits_adapter_1, logits_merged_adapter_default) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits_adapter_1, logits_merged_adapter_default) @parameterized.expand( PeftTestConfigManager.get_grid_parameters( @@ -336,8 +356,8 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs ) ) def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): - # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking - # torch.allclose of logits, we calculate the coeffcient of correlation. This has some precedence, like for HQQ. + # Not using PeftCommonTester for merging tests as merging is too imprecise. 
So instead of checking we use a + # custom check that relies on correlation and outlier removal torch.manual_seed(0) config = config_cls( @@ -361,8 +381,7 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): model = model.merge_and_unload() logits_merged = model(**dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits_unmerged, logits_merged) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits_unmerged, logits_merged) model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -450,8 +469,7 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co pass logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits, logits_merged_from_pretrained) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits, logits_merged_from_pretrained) # TODO: enable if/when mixed batch inference is supported # @parameterized.expand( @@ -482,9 +500,10 @@ def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): - self._test_generate_half_prec(model_id, config_cls, config_kwargs) + # TODO: segfault + # @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + # def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): + # self._test_generate_half_prec(model_id, config_cls, config_kwargs) @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") @@ -658,32 +677,31 @@ def test_quanto_merge_conv2d(self, test_name, model_id, config_cls, config_kwarg if config.is_prompt_learning: pytest.skip("Prompt learning models do not support merging.") - config.target_modules = {"conv2d"} + config.target_modules = {"seq.0", "seq.2", "seq.4"} config.task_type = None class ModelConv2D(nn.Module): def __init__(self): super().__init__() - self.conv2d = nn.Conv2d(5, 10, 3) - self.relu = nn.ReLU() - self.flat = nn.Flatten() - self.lin0 = nn.Linear(10, 2) - self.sm = nn.LogSoftmax(dim=-1) + self.seq = nn.Sequential( + nn.Conv2d(3, 8, 3), + nn.ReLU(), + nn.Conv2d(8, 8, 3), + nn.ReLU(), + nn.Conv2d(8, 8, 3), + nn.ReLU(), + nn.Flatten(), + nn.Linear(800, 64), + ) def forward(self, X): - X = X.float().reshape(-1, 5, 3, 3) - X = self.conv2d(X) - X = self.relu(X) - X = self.flat(X) - X = self.lin0(X) - X = self.sm(X) - return X + return self.seq(X) model = ModelConv2D() model = get_peft_model(model, config) model = model.to(self.torch_device) - dummy_input = torch.arange(90).view(9, 10).to(self.torch_device) + dummy_input = torch.randn(5, 3, 16, 16).to(self.torch_device) model.eval() logits = model(dummy_input)[0] @@ -695,20 +713,25 @@ def forward(self, X): model = model.merge_and_unload() logits_merged_unloaded = model(dummy_input)[0] - cc_matrix = self.get_correlation_matrix(logits, logits_merged, logits_unmerged, logits_merged_unloaded) - assert cc_matrix.min() > self.min_correlation + self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) class 
PeftQuanto2bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): - transformers_class = make_automodel_proxy(weights="int2") + weights = "int2" + transformers_class = make_automodel_proxy(weights=weights) min_correlation = 0.9 + tol = 0.3 class PeftQuanto4bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): - transformers_class = make_automodel_proxy(weights="int4") + weights = "int4" + transformers_class = make_automodel_proxy(weights=weights) min_correlation = 0.95 + tol = 1e-2 class PeftQuanto8bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): - transformers_class = make_automodel_proxy(weights="int8") + weights = "int8" + transformers_class = make_automodel_proxy(weights=weights) min_correlation = 0.95 + tol = 1e-2 From 4b02c8a098331c032cf882999a5e0053c8f4ccb9 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 6 Sep 2024 15:07:59 +0200 Subject: [PATCH 09/23] Better transformers "emulation" --- tests/test_quanto.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index a858588c23..0a26d783de 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -68,6 +68,14 @@ def make_automodel_proxy(weights: str): from optimum.quanto import QuantizedModelForCausalLM, qint2, qint4, qint8 from transformers.utils.quantization_config import QuantizationMethod + # dummy objects to imitate transformers + class QuantizationConfig: + quant_method = QuantizationMethod.QUANTO + + class HfQuantizer: + is_trainable = False + quantization_config = QuantizationConfig() + class QuantoModelProxy: @classmethod def from_pretrained(self, *args, **kwargs): @@ -82,7 +90,8 @@ def from_pretrained(self, *args, **kwargs): # float8 was tried but was producing NaNs raise ValueError(f"Invalid quantization dtype for quanto: {weights}") - model.quantization_method = QuantizationMethod.QUANTO + model.is_quantized = True + model.hf_quantizer = HfQuantizer() return model return QuantoModelProxy From 573583f82afd6392658fd7996f6fc7f15eab0d3f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 28 Oct 2024 16:01:45 +0100 Subject: [PATCH 10/23] Rework tests to use QuantoConfig --- docs/source/developer_guides/quantization.md | 14 +++---- tests/test_quanto.py | 40 ++++---------------- 2 files changed, 13 insertions(+), 41 deletions(-) diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md index efc3c80c6c..fc9fd2e345 100644 --- a/docs/source/developer_guides/quantization.md +++ b/docs/source/developer_guides/quantization.md @@ -216,23 +216,21 @@ model = get_peft_model(base_model, peft_config) PEFT supports models quantized with [optimum-quanto](https://github.com/huggingface/optimum-quanto). This has been tested with 2bit, 4bit, and 8bit int quantization. Optimum-quanto also works on CPU and MPS. ```python -from transformers import AutoModelForCausalLM -from optimum.quanto import QuantizedModelForCausalLM, qint2, qint4, qint8 +from transformers import AutoModelForCausalLM, QuantoConfig model_id = ... -base_model = AutoModelForCausalLM.from_pretrained(model_id) -QuantizedModelForCausalLM.quantize(model, weights=qint4) # or qint2 or qint8 +quantization_config = QuantoConfig(weights="int4") # or qint2 or qint8 +base_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) peft_config = LoraConfig(...) 
model = get_peft_model(base_model, peft_config) ``` - - ### Caveats: -- Use a version > 2.4.0, otherwise saving and loading won't work properly. +- Use optimum-quanto v0.2.5 or above, otherwise saving and loading won't work properly. +- If you want to use optimum-quanto via transformers, install transformers v4.46.0 or above. - Float8 is discouraged as it can easily produce NaNs. -- There is explicit support for optimum-quanto when used with LoRA. However, when optimum-quanto quantizes a layer, it remains a subclass of the corresponding torch class (e.g., quanto's `QLinear` is a subclass of `nn.Linear`). For this reason, methods will generally also work with optimum-quanto, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA**. If you use a method other than LoRA, merging may not raise an error but the results will be incorrect. +- There is explicit support for optimum-quanto when used with LoRA. However, when optimum-quanto quantizes a layer, it remains a subclass of the corresponding torch class (e.g., quanto's `QLinear` is a subclass of `nn.Linear`). For this reason, non-LoRA methods will generally also work with optimum-quanto, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA**. If you use a method other than LoRA, merging may not raise an error but the results will be incorrect. ## Other Supported PEFT Methods diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 0a26d783de..b8d4c8843d 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -56,42 +56,17 @@ def filter_supported_methods_supporting_merging(test_list): def make_automodel_proxy(weights: str): - """Instantiate a quanto-quantized transformers model. - - As quanto is not yet integrated into transformers itself, this is done manually for now but should be replaced once - transformers supports it. - - """ + """Instantiate a quanto-quantized transformers model.""" # TODO: Can't use `from transformers import QuantoConfig` because it checks for the quanto package, but quanto is # now part of optimum, resulting in the check to fail. 
# Switch to QuantoConfig once https://github.com/huggingface/transformers/pull/31732 is merged - from optimum.quanto import QuantizedModelForCausalLM, qint2, qint4, qint8 - from transformers.utils.quantization_config import QuantizationMethod - - # dummy objects to imitate transformers - class QuantizationConfig: - quant_method = QuantizationMethod.QUANTO - - class HfQuantizer: - is_trainable = False - quantization_config = QuantizationConfig() + from transformers import QuantoConfig class QuantoModelProxy: @classmethod def from_pretrained(self, *args, **kwargs): - model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) - if weights == "int2": - QuantizedModelForCausalLM.quantize(model, weights=qint2) - elif weights == "int4": - QuantizedModelForCausalLM.quantize(model, weights=qint4) - elif weights == "int8": - QuantizedModelForCausalLM.quantize(model, weights=qint8) - else: - # float8 was tried but was producing NaNs - raise ValueError(f"Invalid quantization dtype for quanto: {weights}") - - model.is_quantized = True - model.hf_quantizer = HfQuantizer() + quantization_config = QuantoConfig(weights=weights) + model = AutoModelForCausalLM.from_pretrained(*args, quantization_config=quantization_config, **kwargs) return model return QuantoModelProxy @@ -509,10 +484,9 @@ def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) - # TODO: segfault - # @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - # def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): - # self._test_generate_half_prec(model_id, config_cls, config_kwargs) + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): + self._test_generate_half_prec(model_id, config_cls, config_kwargs) @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") From 252e045a982545ef37a9cdd3e363d39c00035f6f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 28 Oct 2024 16:36:56 +0100 Subject: [PATCH 11/23] Enable mixed batch inference for Linear There is an issue with quanto not working with torch.inference_mode that the test needs to work around. --- src/peft/tuners/lora/quanto.py | 50 ++++++++++++++++++++++++++++++---- tests/test_quanto.py | 23 ++++++++-------- tests/testing_common.py | 20 +++++++++++--- 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index 1dda88dd18..d0e08ad12b 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -60,22 +60,60 @@ def __init__( self._active_adapter = adapter_name self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + output = lora_B(lora_A(dropout(sub_batch))) * scaling + if requires_conversion: + output = output.to(expected_dtype) + result[sub_batch_indices_list[i]] += output + + return result + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: - result = self.base_layer(x) + self._check_forward_args(x, *args, **kwargs) adapter_names = kwargs.pop("adapter_names", None) - if adapter_names is not None: - raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.") - - if self.disable_adapters: - return result if self.disable_adapters: if self.merged: self.unmerge() result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) elif self.merged: result = self.base_layer(x, *args, **kwargs) else: + result = self.base_layer(x, *args, **kwargs) for active_adapter in self.active_adapters: if active_adapter not in self.lora_A.keys(): continue diff --git a/tests/test_quanto.py b/tests/test_quanto.py index b8d4c8843d..26fc37ae8f 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -455,18 +455,17 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0] self.check_tensors_approximately_equal(logits, logits_merged_from_pretrained) - # TODO: enable if/when mixed batch inference is supported - # @parameterized.expand( - # PeftTestConfigManager.get_grid_parameters( - # { - # "model_ids": PEFT_DECODER_MODELS_TO_TEST, - # "lora_kwargs": {"init_lora_weights": [False]}, - # "task_type": "CAUSAL_LM", - # }, - # ) - # ) - # def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): - # self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): + self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) def test_generate(self, test_name, model_id, config_cls, config_kwargs): diff --git a/tests/testing_common.py b/tests/testing_common.py index bdd9bd301d..2441d4e0ef 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -20,6 +20,7 @@ import tempfile import warnings from collections 
import OrderedDict +from contextlib import nullcontext from dataclasses import replace import pytest @@ -868,6 +869,8 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): if config_cls not in (LoraConfig,): return pytest.skip(f"Mixed adapter batches not supported for {config_cls}") + from transformers.quantizers.quantizer_quanto import QuantoHfQuantizer + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -883,18 +886,27 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): # ensure that we have at least 3 samples for this test dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()} - with torch.inference_mode(): + # Using quanto with inference model raises an error: + # > RuntimeError: Cannot set version_counter for inference tensor + # https://github.com/huggingface/optimum-quanto/issues/304 + # TODO: remove when/if this is fixed + if isinstance(getattr(model, "hf_quantizer", None), QuantoHfQuantizer): + inference_mode = nullcontext + else: + inference_mode = torch.inference_mode + + with inference_mode(): with model.disable_adapter(): output_base = model(**dummy_input)[0] logits_base = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] model.set_adapter("adapter0") - with torch.inference_mode(): + with inference_mode(): output_adapter0 = model(**dummy_input)[0] logits_adapter0 = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] model.set_adapter("adapter1") - with torch.inference_mode(): + with inference_mode(): output_adapter1 = model(**dummy_input)[0] logits_adapter1 = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] @@ -913,7 +925,7 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): adapters = ["__base__", "adapter0", "adapter1"] dummy_input["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))] - with torch.inference_mode(): + with inference_mode(): output_mixed = model(**dummy_input)[0] logits_mixed = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] From 2773b174e515170ba5ea52a4d9959f026a0048e3 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 28 Oct 2024 16:39:12 +0100 Subject: [PATCH 12/23] Remove obsolete comment --- tests/test_quanto.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 26fc37ae8f..41a28a6200 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -57,9 +57,6 @@ def filter_supported_methods_supporting_merging(test_list): def make_automodel_proxy(weights: str): """Instantiate a quanto-quantized transformers model.""" - # TODO: Can't use `from transformers import QuantoConfig` because it checks for the quanto package, but quanto is - # now part of optimum, resulting in the check to fail. 
- # Switch to QuantoConfig once https://github.com/huggingface/transformers/pull/31732 is merged from transformers import QuantoConfig class QuantoModelProxy: From f240c1c07f89628bf36735442bbe9660b351a849 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Oct 2024 13:11:10 +0100 Subject: [PATCH 13/23] Refactor merging to make tests pass --- src/peft/tuners/lora/quanto.py | 66 ++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index d0e08ad12b..73a2bfc098 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -142,47 +142,59 @@ def get_delta_weight(self, adapter): ) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: - from optimum.quanto import quantize_weight - adapter_names = check_adapters_to_merge(self, adapter_names) if not adapter_names: # no adapter to merge return - base_layer = self.get_base_layer() - orig_weight = base_layer.weight + with torch.no_grad(): + new_module = torch.nn.Linear( + self.in_features, self.out_features, device=self.lora_A[adapter_names[0]].weight.device + ) + new_module.weight.zero_() + new_module.bias.zero_() - for active_adapter in adapter_names: - delta_weight = self.get_delta_weight(active_adapter) - # note: no in-place for safe_merge=False - new_weight_data = orig_weight + delta_weight - if safe_merge and not torch.isfinite(new_weight_data).all(): + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + new_module.weight.data += orig_weight + new_module.bias.data += base_layer.bias + + for active_adapter in adapter_names: + new_module.weight.data += self.get_delta_weight(active_adapter) + + quantized = base_layer.from_module(new_module, weights=base_layer.weight_qtype).qweight + if safe_merge and not torch.isfinite(quantized).all(): raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - quantized = quantize_weight(new_weight_data, qtype=base_layer.qweight.qtype, axis=base_layer.qweight.axis) - base_layer.weight._data = quantized._data - base_layer.weight._scale = quantized._scale - self.merged_adapters.append(active_adapter) + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + self.merged_adapters.extend(adapter_names) def unmerge(self) -> None: - from optimum.quanto import quantize_weight - if not self.merged: warnings.warn("Already unmerged. 
Nothing to do.") return - while len(self.merged_adapters) > 0: - active_adapter = self.merged_adapters.pop() - if active_adapter not in self.lora_A.keys(): - continue + with torch.no_grad(): + new_module = torch.nn.Linear( + self.in_features, self.out_features, device=self.lora_A[self.active_adapters[0]].weight.device + ) + new_module.weight.zero_() + new_module.bias.zero_() base_layer = self.get_base_layer() - orig_weight = base_layer.weight - new_weight_data = orig_weight - self.get_delta_weight(active_adapter) - quantized = quantize_weight(new_weight_data, qtype=base_layer.qweight.qtype, axis=base_layer.qweight.axis) - base_layer.weight._data = quantized._data - base_layer.weight._scale = quantized._scale + orig_weight = base_layer.qweight + new_module.weight.data += orig_weight + new_module.bias.data += base_layer.bias + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + new_module.weight.data -= self.get_delta_weight(active_adapter) + + quantized = base_layer.from_module(new_module, weights=base_layer.weight_qtype).qweight + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale def __repr__(self) -> str: rep = super().__repr__() @@ -344,7 +356,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N return base_layer = self.get_base_layer() - orig_weight = base_layer.weight + orig_weight = base_layer.qweight for active_adapter in adapter_names: delta_weight = self.get_delta_weight(active_adapter) @@ -356,8 +368,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) - base_layer.weight._data = quantized._data - base_layer.weight._scale = quantized._scale + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale self.merged_adapters.append(active_adapter) def unmerge(self) -> None: From 63e5cdb1c67dd6801ace1b2414ad6a37bee20570 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Oct 2024 14:07:08 +0100 Subject: [PATCH 14/23] Optimum-quanto import check and install for CI --- setup.py | 1 + src/peft/import_utils.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1659facd0e..af3d3a3dd2 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ "scipy", "protobuf", "sentencepiece", + "optimum-quanto", ] setup( diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 4c7d2b040a..1bb9dafac0 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -107,4 +107,6 @@ def is_torchao_available(): @lru_cache def is_quanto_available(): - return importlib.util.find_spec("optimum.quanto") is not None + return (importlib.util.find_spec("optimum") is not None) and ( + importlib.util.find_spec("optimum.quanto") is not None + ) From 3862b41727cf5283dd409150c3cc9a46170bac3b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Oct 2024 14:27:39 +0100 Subject: [PATCH 15/23] Fix import check --- src/peft/tuners/lora/quanto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index 73a2bfc098..0cd1e08f59 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -27,7 +27,7 @@ from peft.utils.other import transpose -if is_quanto_available: +if is_quanto_available(): # ensure 
that there are no quanto imports unless optimum.quanto is installed from optimum.quanto import QConv2d, QLinear else: From 85d096f4c0a1b38c8f8340d03db35c1f9410db94 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Oct 2024 15:11:01 +0100 Subject: [PATCH 16/23] Apply test filter where appropriate --- tests/test_quanto.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 41a28a6200..462e60ef11 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -42,11 +42,9 @@ def filter_supported_methods_supporting_merging(test_list): return [test for test in test_list if any(test[2] is cls for cls in PEFT_METHODS_SUPPORTING_MERGING)] +# only test a single model, it's already slow as is PEFT_DECODER_MODELS_TO_TEST = [ "hf-internal-testing/tiny-random-OPTForCausalLM", - # "hf-internal-testing/tiny-random-GPT2LMHeadModel", - # "trl-internal-testing/tiny-random-LlamaForCausalLM", - # "peft-internal-testing/tiny-dummy-qwen2", ] FULL_GRID = { @@ -459,6 +457,7 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co "lora_kwargs": {"init_lora_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=filter_supported_methods_supporting_merging, ) ) def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): @@ -534,6 +533,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=filter_supported_methods_supporting_merging, ) ) def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): @@ -570,6 +570,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, + filter_params_func=filter_supported_methods_supporting_merging, ) ) def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): From 1538cac5393cda0420cb23e772b8ad8645d1f362 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 29 Oct 2024 15:49:59 +0100 Subject: [PATCH 17/23] Skip MacOS, comment a segfaulting test --- tests/test_quanto.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 462e60ef11..0bd6146fb7 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import platform import shutil import tempfile import unittest @@ -67,6 +68,7 @@ def from_pretrained(self, *args, **kwargs): return QuantoModelProxy +@unittest.skipIf(platform.system() == "Darwin", "Tests are skipped on macOS") class BasePeftQuantoModelTester: r"""Base class implementing tests for quanto-quantized models. 
@@ -479,9 +481,10 @@ def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs) def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): - self._test_generate_half_prec(model_id, config_cls, config_kwargs) + # this fails for a couple of methods (IA³, LoRA, prefix tuning) with segfault on GH CI + # @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + # def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): + # self._test_generate_half_prec(model_id, config_cls, config_kwargs) @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") From c86cee0d06493bde4403794507e449dd51f5119d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 10 Jan 2025 15:03:35 +0100 Subject: [PATCH 18/23] Some fixes for quanto + hf_device_map --- src/peft/utils/integrations.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py index 5be6300c0f..0b24b5f2aa 100644 --- a/src/peft/utils/integrations.py +++ b/src/peft/utils/integrations.py @@ -128,15 +128,22 @@ def get_layer_device_map(model): """ Derive the device map for the layers of the model. """ - main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] + if not hasattr(model, "hf_device_map"): + return None + + if (len(model.hf_device_map) == 1) and hasattr(model, "device"): + # E.g. with quanto, when the model is loaded as: + # `model = AutoModel.from_pretrained(model_id, quantization_config=quanto_config)` + # Then the model.hf_device_map is set to {'': 'cpu'}, even if model.to(0) is called later. Thus we can't fully + # rely on the hf_device_map. 
+ main_device = model.device + else: + main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] execution_device_map = { name: main_device if device in ["cpu", "disk"] else device for name, device in model.hf_device_map.items() } - if execution_device_map is None: - return None - if len(execution_device_map) == 1 and "" in execution_device_map: return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)} @@ -166,6 +173,9 @@ def map_cache_to_layer_device_map(model, cache) -> None: return layer_device_map = get_layer_device_map(model) + if layer_device_map is None: + return + for idx in range(model.config.num_hidden_layers): layer_device = layer_device_map[idx] cache.key_cache[idx] = cache.key_cache[idx].to(layer_device) From c28046cfc048abd3f103db979b1a49cdbb95d23d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 18 Jun 2025 18:22:29 +0200 Subject: [PATCH 19/23] Update test_quanto to use pytest, fix small bug - Refactor from unittest to pytest - Fix a small bug with merging when there is no bias --- src/peft/tuners/lora/quanto.py | 6 +- tests/test_quanto.py | 363 ++++++++++++++------------------- 2 files changed, 157 insertions(+), 212 deletions(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index 0cd1e08f59..430ba9aaab 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -157,7 +157,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N base_layer = self.get_base_layer() orig_weight = base_layer.qweight new_module.weight.data += orig_weight - new_module.bias.data += base_layer.bias + if getattr(base_layer, "bias", None) is not None: + new_module.bias.data += base_layer.bias for active_adapter in adapter_names: new_module.weight.data += self.get_delta_weight(active_adapter) @@ -186,7 +187,8 @@ def unmerge(self) -> None: base_layer = self.get_base_layer() orig_weight = base_layer.qweight new_module.weight.data += orig_weight - new_module.bias.data += base_layer.bias + if getattr(base_layer, "bias", None) is not None: + new_module.bias.data += base_layer.bias while len(self.merged_adapters) > 0: active_adapter = self.merged_adapters.pop() diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 5d10a8070a..39651f1d81 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -20,7 +20,6 @@ import pytest import torch -from parameterized import parameterized from torch import nn from transformers import AutoModelForCausalLM, AutoTokenizer @@ -29,29 +28,50 @@ PrefixTuningConfig, PromptTuningConfig, PromptTuningInit, + TaskType, get_peft_model, ) -from .testing_common import PeftCommonTester, PeftTestConfigManager +from .testing_common import PeftCommonTester -# only the PEFT methods that are explicitly supported will be tested for merging -PEFT_METHODS_SUPPORTING_MERGING = [LoraConfig] - - -def filter_supported_methods_supporting_merging(test_list): - return [test for test in test_list if any(test[2] is cls for cls in PEFT_METHODS_SUPPORTING_MERGING)] +MODELS_TO_TEST = [ + "trl-internal-testing/tiny-random-LlamaForCausalLM", +] -# only test a single model, it's already slow as is -PEFT_DECODER_MODELS_TO_TEST = [ - "hf-internal-testing/tiny-random-OPTForCausalLM", +ALL_CONFIGS = [ + ( + LoraConfig, + { + "r": 8, + "lora_alpha": 32, + "target_modules": None, + "lora_dropout": 0.05, + "bias": "none", + "task_type": TaskType.CAUSAL_LM, + }, + ), + ( + PrefixTuningConfig, + { + "num_virtual_tokens": 10, + "task_type": 
TaskType.CAUSAL_LM, + }, + ), + ( + PromptTuningConfig, + { + "num_virtual_tokens": 10, + "task_type": TaskType.CAUSAL_LM, + }, + ), ] -FULL_GRID = { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "task_type": "CAUSAL_LM", -} + +def _skip_if_merging_not_supported(model_id, config_cls): + if config_cls in (PrefixTuningConfig, PromptTuningConfig): + pytest.skip("This PEFT method does not support merging") def make_automodel_proxy(weights: str): @@ -88,6 +108,10 @@ class BasePeftQuantoModelTester: # the allowed tolerance for comparing the output tensors tol = "MISSING" + def skipTest(self, reason=""): + # for backwards compatibility with unittest style test classes + pytest.skip(reason) + def _get_correlation_matrix(self, *tensors): return torch.corrcoef(torch.stack([t.flatten() for t in tensors])) @@ -118,20 +142,24 @@ def prepare_inputs_for_testing(self): return input_dict - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_attributes_parametrized(self, model_id, config_cls, config_kwargs): self._test_model_attr(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_adapter_name(self, model_id, config_cls, config_kwargs): self._test_adapter_name(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_prepare_for_training_parametrized(self, model_id, config_cls, config_kwargs): self._test_prepare_for_training(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_prompt_tuning_text_prepare_for_training(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_prompt_tuning_text_prepare_for_training(self, model_id, config_cls, config_kwargs): # Test that prompt tuning works with text init if config_cls != PromptTuningConfig: return pytest.skip(f"This test does not apply to {config_cls}") @@ -184,45 +212,37 @@ def test_prompt_tuning_config_invalid_args(self): tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, ) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained(self, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + 
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained_pickle(self, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained_selected_adapters_pickle(self, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_from_pretrained_config_construction(self, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "adalora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "vera_kwargs": {"init_weights": [False]}, - "fourierft_kwargs": {"init_weights": [False]}, - "hra_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) - def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_merge_layers(self, model_id, config_cls, config_kwargs): # Not using PeftCommonTester for merging tests as merging is too imprecise. 
So instead of checking we use a # custom check that relies on correlation and outlier removal + _skip_if_merging_not_supported(model_id, config_cls) torch.manual_seed(0) config = config_cls( @@ -250,27 +270,15 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "vera_kwargs": {"init_weights": [False]}, - "fourierft_kwargs": {"init_weights": [False]}, - "hra_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) # TODO: enable if/when deepcopy-ing is supported @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") - def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): + def test_merge_layers_multi(self, model_id, config_cls, config_kwargs): # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a # custom check that relies on correlation and outlier removal # NOTE: don't use with `torch.inference_mode()`, see: https://github.com/huggingface/optimum-quanto/issues/304 + _skip_if_merging_not_supported(model_id, config_cls) torch.manual_seed(0) config = config_cls( @@ -324,21 +332,12 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs self.check_tensors_approximately_equal(logits_adapter_1, logits_merged_adapter_default) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) - def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_merge_layers_nan(self, model_id, config_cls, config_kwargs): # Not using PeftCommonTester for merging tests as merging is too imprecise. 
So instead of checking we use a # custom check that relies on correlation and outlier removal + _skip_if_merging_not_supported(model_id, config_cls) torch.manual_seed(0) config = config_cls( @@ -393,24 +392,10 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): ): model = model.merge_and_unload(safe_merge=True) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "adalora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "vera_kwargs": {"init_weights": [False]}, - "fourierft_kwargs": {"init_weights": [False]}, - "hra_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) @pytest.mark.xfail(strict=True) - def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, config_kwargs): + def test_load_merge_and_unloaded_model(self, model_id, config_cls, config_kwargs): # Saving and loading a quanto model that has been merged and unloaded does not work correctly. Here is the # reason: Quanto requires its own save_pretrained method, which, among others, saves the quantization map. # Without it, the model cannot be correctly loaded. To make use of this, we should thus use a quanto @@ -420,6 +405,7 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co # save_pretrained, thus loading the merged and unloaded model does not work. from optimum.quanto import QuantizedModelForCausalLM + _skip_if_merging_not_supported(model_id, config_cls) torch.manual_seed(0) model = self.transformers_class.from_pretrained(model_id) @@ -452,136 +438,108 @@ def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, co logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0] self.check_tensors_approximately_equal(logits, logits_merged_from_pretrained) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) - def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_generate(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_generate(self, model_id, config_cls, config_kwargs): self._test_generate(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def 
test_generate_pos_args(self, model_id, config_cls, config_kwargs): # positional args are supported for PeftModelForCausalLM self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters(FULL_GRID), - filter_params_func=filter_supported_methods_supporting_merging, - ) - def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) # this fails for a couple of methods (IA³, LoRA, prefix tuning) with segfault on GH CI - # @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - # def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): - # self._test_generate_half_prec(model_id, config_cls, config_kwargs) + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_generate_half_prec(self, model_id, config_cls, config_kwargs): + self._test_generate_half_prec(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") - def test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): + def test_prefix_tuning_half_prec_conversion(self, model_id, config_cls, config_kwargs): self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_decoders(self, model_id, config_cls, config_kwargs): self._test_training(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwargs): self._test_training_layer_indexing(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs): self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", 
ALL_CONFIGS) + def test_inference_safetensors(self, model_id, config_cls, config_kwargs): self._test_inference_safetensors(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_peft_model_device_map(self, model_id, config_cls, config_kwargs): self._test_peft_model_device_map(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_delete_adapter(self, model_id, config_cls, config_kwargs): self._test_delete_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "adalora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "vera_kwargs": {"init_weights": [False]}, - "fourierft_kwargs": {"init_weights": [False]}, - "hra_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) - def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_unload_adapter(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) self._test_unload_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - ) - ) - def test_weighted_combination_of_adapters(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, 
config_cls) self._test_weighted_combination_of_adapters(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs): self._test_training_prompt_learning_tasks(model_id, config_cls, config_kwargs) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "adalora_kwargs": {"init_lora_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "vera_kwargs": {"init_weights": [False]}, - "fourierft_kwargs": {"init_weights": [False]}, - "hra_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) - def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_disable_adapter(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) self._test_disable_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) - def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): - self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_passing_input_embeds_works(self, model_id, config_cls, config_kwargs): + self._test_passing_input_embeds_works(self, model_id, config_cls, config_kwargs) # TODO: enable if/when deepcopy-ing is supported @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") @@ -634,31 +592,16 @@ def test_prompt_learning_with_grouped_query_attention(self): # does not raise model(x) - @parameterized.expand( - PeftTestConfigManager.get_grid_parameters( - { - "model_ids": PEFT_DECODER_MODELS_TO_TEST, - "lora_kwargs": {"init_lora_weights": [False]}, - "adalora_kwargs": {"init_lora_weights": [False]}, - "ia3_kwargs": {"init_ia3_weights": [False]}, - "boft_kwargs": {"init_weights": [False]}, - "vera_kwargs": {"init_weights": [False]}, - "fourierft_kwargs": {"init_weights": [False]}, - "hra_kwargs": {"init_weights": [False]}, - "task_type": "CAUSAL_LM", - }, - filter_params_func=filter_supported_methods_supporting_merging, - ) - ) - def test_quanto_merge_conv2d(self, test_name, model_id, config_cls, config_kwargs): + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_quanto_merge_conv2d(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) torch.manual_seed(0) config = config_cls( base_model_name_or_path=model_id, **config_kwargs, ) - if config.is_prompt_learning: - pytest.skip("Prompt learning models do not support merging.") config.target_modules = {"seq.0", "seq.2", "seq.4"} config.task_type = None @@ -699,21 +642,21 @@ def forward(self, X): self.check_tensors_approximately_equal(logits, logits_merged, 
logits_unmerged, logits_merged_unloaded) -class PeftQuanto2bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): +class TestPeftQuanto2bitModel(PeftCommonTester, BasePeftQuantoModelTester): weights = "int2" transformers_class = make_automodel_proxy(weights=weights) min_correlation = 0.9 tol = 0.3 -class PeftQuanto4bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): +class TestPeftQuanto4bitModel(PeftCommonTester, BasePeftQuantoModelTester): weights = "int4" transformers_class = make_automodel_proxy(weights=weights) min_correlation = 0.95 tol = 1e-2 -class PeftQuanto8bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): +class TestPeftQuanto8bitModel(PeftCommonTester, BasePeftQuantoModelTester): weights = "int8" transformers_class = make_automodel_proxy(weights=weights) min_correlation = 0.95 From 701ded9ee75b7d91cfde3197f063d28d35b6050c Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 19 Jun 2025 11:41:54 +0200 Subject: [PATCH 20/23] Fix some testing issues, skip Windows Unfortunately, Windows tests fail (again?) with: > ImportError: DLL load failed while importing quanto_cpp: The specified module could not be found. --- src/peft/tuners/lora/quanto.py | 17 ++++++----------- tests/test_quanto.py | 11 ++++++----- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py index 430ba9aaab..f7eb1c5047 100644 --- a/src/peft/tuners/lora/quanto.py +++ b/src/peft/tuners/lora/quanto.py @@ -66,6 +66,9 @@ def _mixed_batch_forward( # This is a special method that handles the case when users pass the argument `adapter_names`. This is an # extra argument that allows mixing different adapters in the same batch at inference time. 
result = self.base_layer(x, *args, **kwargs) + # some quanto quantizations may require cloning or else will fail later when assigning the lora output in-place + result = result.clone() + torch_result_dtype = result.dtype unique_adapters = set(adapter_names) sub_batch_indices_list = [] @@ -82,21 +85,13 @@ def _mixed_batch_forward( lora_B = self.lora_B[active_adapter] dropout = self.lora_dropout[active_adapter] scaling = self.scaling[active_adapter] - - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - compute_dtype = lora_A.weight.dtype - if x.dtype != compute_dtype: - x = x.to(compute_dtype) + x = self._cast_input_dtype(x, lora_A.weight.dtype) # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear # layer output sub_batch = x[sub_batch_indices_list[i]] - output = lora_B(lora_A(dropout(sub_batch))) * scaling - if requires_conversion: - output = output.to(expected_dtype) - result[sub_batch_indices_list[i]] += output + lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling + result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype) return result diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 39651f1d81..d09483ce4e 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -33,8 +33,10 @@ ) from .testing_common import PeftCommonTester +from .testing_utils import set_init_weights_false +# only test a small subset of models and PEFT methods, testing exhaustively would be slow for little benefit MODELS_TO_TEST = [ "trl-internal-testing/tiny-random-LlamaForCausalLM", ] @@ -88,7 +90,8 @@ def from_pretrained(self, *args, **kwargs): return QuantoModelProxy -@unittest.skipIf(platform.system() == "Darwin", "Tests are skipped on macOS") +# Seeing issues on CI with MacOS and Windows, so skipping them for now +@unittest.skipIf(platform.system() != "Linux", "Tests are skipped on macOS and Windows") class BasePeftQuantoModelTester: r"""Base class implementing tests for quanto-quantized models. 
@@ -285,8 +288,6 @@ def test_merge_layers_multi(self, model_id, config_cls, config_kwargs): base_model_name_or_path=model_id, **config_kwargs, ) - if config.is_prompt_learning: - pytest.skip("Prompt learning models do not support merging.") model = self.transformers_class.from_pretrained(model_id) model = get_peft_model(model, config) @@ -344,8 +345,6 @@ def test_merge_layers_nan(self, model_id, config_cls, config_kwargs): base_model_name_or_path=model_id, **config_kwargs, ) - if config.is_prompt_learning: - pytest.skip("Prompt learning models do not support merging.") model = self.transformers_class.from_pretrained(model_id) model = get_peft_model(model, config) @@ -442,6 +441,7 @@ def test_load_merge_and_unloaded_model(self, model_id, config_cls, config_kwargs @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) def test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): _skip_if_merging_not_supported(model_id, config_cls) + config_kwargs = set_init_weights_false(config_cls, config_kwargs) self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) @pytest.mark.parametrize("model_id", MODELS_TO_TEST) @@ -517,6 +517,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, c @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) def test_unload_adapter(self, model_id, config_cls, config_kwargs): _skip_if_merging_not_supported(model_id, config_cls) + config_kwargs = set_init_weights_false(config_cls, config_kwargs) self._test_unload_adapter(model_id, config_cls, config_kwargs) @pytest.mark.parametrize("model_id", MODELS_TO_TEST) From 9773da2b693050b577064264a62ab2afd85b233b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 19 Jun 2025 13:04:28 +0200 Subject: [PATCH 21/23] Fix skip marker in test --- tests/test_quanto.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index d09483ce4e..50cb66d659 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -15,7 +15,6 @@ import platform import shutil import tempfile -import unittest from unittest.mock import Mock, call, patch import pytest @@ -91,7 +90,7 @@ def from_pretrained(self, *args, **kwargs): # Seeing issues on CI with MacOS and Windows, so skipping them for now -@unittest.skipIf(platform.system() != "Linux", "Tests are skipped on macOS and Windows") +@pytest.mark.skipif(platform.system() != "Linux", "Tests are skipped on macOS and Windows") class BasePeftQuantoModelTester: r"""Base class implementing tests for quanto-quantized models. From d90b5416b76bafe3ae7d4b09177aa4ff5dbe8091 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 19 Jun 2025 13:57:42 +0200 Subject: [PATCH 22/23] Fix skipif call --- tests/test_quanto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_quanto.py b/tests/test_quanto.py index 50cb66d659..0e0c2fc249 100644 --- a/tests/test_quanto.py +++ b/tests/test_quanto.py @@ -90,7 +90,7 @@ def from_pretrained(self, *args, **kwargs): # Seeing issues on CI with MacOS and Windows, so skipping them for now -@pytest.mark.skipif(platform.system() != "Linux", "Tests are skipped on macOS and Windows") +@pytest.mark.skipif(platform.system() != "Linux", reason="Tests are skipped on macOS and Windows") class BasePeftQuantoModelTester: r"""Base class implementing tests for quanto-quantized models. 
From 2eab7f4bdb2bfe38f51909515dbc7b62896f5559 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan 
Date: Wed, 20 Aug 2025 12:44:30 +0200
Subject: [PATCH 23/23] Fixing some tests

---
 tests/test_quanto.py    |  1 +
 tests/testing_common.py | 20 ++++++++++++++++----
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/tests/test_quanto.py b/tests/test_quanto.py
index 0e0c2fc249..3eee78ecca 100644
--- a/tests/test_quanto.py
+++ b/tests/test_quanto.py
@@ -534,6 +534,7 @@ def test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwarg
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
     def test_disable_adapter(self, model_id, config_cls, config_kwargs):
         _skip_if_merging_not_supported(model_id, config_cls)
+        config_kwargs = set_init_weights_false(config_cls, config_kwargs)
         self._test_disable_adapter(model_id, config_cls, config_kwargs)

     @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 8625ea8688..3c4af3b54e 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -1572,11 +1572,24 @@ def _test_delete_unknown_adapter_raises(self, model_id, config_cls, config_kwarg
             model.delete_adapter("unknown-adapter")

     def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
+        from transformers.quantizers.quantizer_quanto import QuantoHfQuantizer
+
         with hub_online_once(model_id):
             model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
+            model = model.to(self.torch_device)
+
+            # Using quanto with torch.inference_mode raises an error:
+            # > RuntimeError: Cannot set version_counter for inference tensor
+            # https://github.com/huggingface/optimum-quanto/issues/304
+            # TODO: remove when/if this is fixed
+            if isinstance(getattr(model, "hf_quantizer", None), QuantoHfQuantizer):
+                inference_mode = nullcontext
+            else:
+                inference_mode = torch.inference_mode
+
             num_params_base = len(model.state_dict())
             dummy_input = self.prepare_inputs_for_testing()
-            with torch.inference_mode():
+            with inference_mode():
                 logits_transformers = model(**dummy_input)[0]

             config = config_cls(
@@ -1584,7 +1597,6 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
                 **config_kwargs,
             )
             model = get_peft_model(model, config)
-            model = model.to(self.torch_device)

             if isinstance(config, PromptLearningConfig):
                 # prompt learning does not support unloading
@@ -1592,13 +1604,13 @@
                 model = model.unload()
             else:
                 self.perturb_trainable_token_weights_if_used(model, config_kwargs)
-                with torch.inference_mode():
+                with inference_mode():
                     logits_with_adapter = model(**dummy_input)[0]

                 model.eval()
                 model = model.unload()
                 num_params_unloaded = len(model.state_dict())
-                with torch.inference_mode():
+                with inference_mode():
                     logits_unload = model(**dummy_input)[0]

                 # check that PEFT layers are completely removed
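
Note on the workaround in the last patch: the same quanto-aware choice between torch.inference_mode and a no-op context manager can be reproduced standalone. The sketch below is illustrative only; the helper name pick_inference_context is hypothetical and not part of PEFT or transformers, everything else follows the diff above.

from contextlib import nullcontext

import torch


def pick_inference_context(model):
    # torch.inference_mode() currently fails for quanto-quantized models because
    # writing to inference tensors raises "Cannot set version_counter for inference
    # tensor" (https://github.com/huggingface/optimum-quanto/issues/304), so fall
    # back to a no-op context manager when a quanto quantizer is attached.
    try:
        from transformers.quantizers.quantizer_quanto import QuantoHfQuantizer
    except ImportError:
        return torch.inference_mode
    if isinstance(getattr(model, "hf_quantizer", None), QuantoHfQuantizer):
        return nullcontext
    return torch.inference_mode


# usage sketch:
# model = AutoModelForCausalLM.from_pretrained(...)  # possibly quanto-quantized
# inference_mode = pick_inference_context(model)
# with inference_mode():
#     logits = model(**inputs)[0]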