From 9577efc9d605a207cb5e2fcda5a81f3526e8a876 Mon Sep 17 00:00:00 2001
From: Charles Goddard
Date: Mon, 7 Oct 2024 18:37:14 -0700
Subject: [PATCH] Initial implementation of PCB merge method

---
 README.md                          |  13 +++-
 mergekit/card.py                   |   3 +
 mergekit/merge_methods/__init__.py |   6 +-
 mergekit/merge_methods/pcb.py      | 118 +++++++++++++++++++++++++++++
 tests/test_basic_merges.py         |  20 +++++
 5 files changed, 157 insertions(+), 3 deletions(-)
 create mode 100644 mergekit/merge_methods/pcb.py

diff --git a/README.md b/README.md
index cbf93b12..f97e2574 100644
--- a/README.md
+++ b/README.md
@@ -130,6 +130,8 @@ A quick overview of the currently supported merge methods:
 | [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
 | [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
 | [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
+| [PCB](https://arxiv.org/abs/2410.02396) | `pcb` | ✅ | ✅ |
+
 ### Linear
 
 The classic merge method - a simple weighted average.
@@ -195,10 +197,20 @@
 Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters similar to [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign elect step of TIES
 
 Parameters: same as [Linear](#linear), plus:
+
 - `density` - fraction of weights in differences from the base model to retain
 - `epsilon` - maximum change in drop probability based on magnitude. Drop probabilities assigned will range from `density - epsilon` to `density + epsilon`. (When selecting values for `density` and `epsilon`, ensure that the range of probabilities falls within 0 to 1)
 - `lambda` - scaling factor for the final merged delta parameters before merging with the base parameters.
 
+### [PCB](https://arxiv.org/abs/2410.02396)
+
+PCB (Parameter Competition Balancing) is a heuristic approach to determining the relative importance of the parameters in each task vector. It combines an intra-task and an inter-task importance term to decide both the weighting and the sparsification of each parameter.
+
+Parameters:
+
+- `density` - fraction of weights in differences from the base model to retain
+- `weight` - scaling factor applied to the final combined task vector before it is added to the base model.
+
 ## LoRA extraction
 
 Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned models.
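For readers who want to try the new method, a minimal configuration might look like the sketch below. This is illustrative only and not part of the patch: the model names are placeholders borrowed from the repository's other example configs, and the parameter values are examples rather than recommended defaults.

```yaml
models:
  - model: psmathur/orca_mini_v3_13b
  - model: garage-bAInd/Platypus2-13B
merge_method: pcb
base_model: TheBloke/Llama-2-13B-fp16
parameters:
  density: 0.9
  weight: 1.0
dtype: float16
```

Here `density: 0.9` keeps the 90% of delta parameters with the highest balancing scores in each task vector, and `weight` scales the combined task vector before it is added to `base_model`.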
@@ -241,7 +253,6 @@ Or download your merge: `!arcee merging download bio-merge`
 
 
-
 ## Citation
 
 We now have a [paper](https://arxiv.org/abs/2403.13257) you can cite for the MergeKit library:
 
diff --git a/mergekit/card.py b/mergekit/card.py
index bf0a2d0a..5247b75a 100644
--- a/mergekit/card.py
+++ b/mergekit/card.py
@@ -118,6 +118,9 @@ def method_md(merge_method: str) -> str:
         "dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)",
         "dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)",
         "model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)",
+        "della": "[DELLA](https://arxiv.org/abs/2406.11617)",
+        "della_linear": "linear [DELLA](https://arxiv.org/abs/2406.11617)",
+        "pcb": "[PCB](https://arxiv.org/abs/2410.02396)",
     }
     return methods.get(merge_method, merge_method)
 
diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py
index 007e163e..f704ca6f 100644
--- a/mergekit/merge_methods/__init__.py
+++ b/mergekit/merge_methods/__init__.py
@@ -22,6 +22,7 @@
 from mergekit.merge_methods.linear import LinearMerge
 from mergekit.merge_methods.model_stock import ModelStockMerge
 from mergekit.merge_methods.passthrough import PassthroughMerge
+from mergekit.merge_methods.pcb import PCBMerge
 from mergekit.merge_methods.slerp import SlerpMerge
 from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
 
@@ -77,7 +78,6 @@ def get(method: str) -> MergeMethod:
         )
     elif method == "model_stock":
         return ModelStockMerge()
-
     elif method == "della":
         return GeneralizedTaskArithmeticMerge(
             consensus_method=ConsensusMethod.sum,
@@ -85,7 +85,6 @@ def get(method: str) -> MergeMethod:
             default_normalize=True,
             default_rescale=True,
         )
-
     elif method == "della_linear":
         return GeneralizedTaskArithmeticMerge(
             consensus_method=None,
@@ -93,6 +92,9 @@ def get(method: str) -> MergeMethod:
             default_normalize=False,
             default_rescale=True,
         )
+    elif method == "pcb":
+        return PCBMerge()
+
     raise RuntimeError(f"Unimplemented merge method {method}")
 
 
diff --git a/mergekit/merge_methods/pcb.py b/mergekit/merge_methods/pcb.py
new file mode 100644
index 00000000..3db53944
--- /dev/null
+++ b/mergekit/merge_methods/pcb.py
@@ -0,0 +1,118 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+import logging
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from pydantic import BaseModel
+from typing_extensions import Literal
+
+from mergekit.architecture import WeightInfo
+from mergekit.common import ImmutableMap, ModelReference
+from mergekit.graph import Task
+from mergekit.merge_methods.base import (
+    ConfigParameterDef,
+    MergeMethod,
+    MergeTensorInput,
+)
+from mergekit.merge_methods.generalized_task_arithmetic import get_task_vectors
+
+
+class PCBMerge(MergeMethod):
+    def parameters(self) -> List[ConfigParameterDef]:
+        return [
+            ConfigParameterDef(name="density", required=True),
+            ConfigParameterDef(name="weight", required=False, default_value=1.0),
+        ]
+
+    def make_task(
+        self,
+        output_weight: WeightInfo,
+        tensors: MergeTensorInput,
+        base_model: Optional[ModelReference],
+        parameters: ImmutableMap[str, Any],
+        **kwargs,
+    ) -> Task[torch.Tensor]:
+        return PCBMergeTask(
+            output_weight=output_weight,
+            tensors=tensors,
+            base_model=base_model,
+            density=parameters["density"],
+            weight=parameters["weight"],
+        )
+
+
+class PCBMergeTask(Task[torch.Tensor]):
+    output_weight: WeightInfo
+    tensors: MergeTensorInput
+    base_model: Optional[ModelReference]
+    density: float
+    weight: float
+
+    def uses_accelerator(self) -> bool:
+        return True
+
+    def arguments(self) -> Dict[str, Task]:
+        return {"tensors": self.tensors}
+
+    def execute(
+        self,
+        tensors: Dict[ModelReference, torch.Tensor],
+        **_kwargs,
+    ) -> torch.Tensor:
+        # collect task vectors (deltas from the base model)
+        tv_info, base = get_task_vectors(
+            self.output_weight,
+            self.base_model,
+            tensors,
+            tensor_parameters=ImmutableMap({model: {} for model in tensors}),
+        )
+        if not tv_info:
+            return base
+
+        n = len(tv_info)
+        tvs = torch.stack([tv["delta"] for tv in tv_info], dim=0)
+        tvs_flat = tvs.view(n, -1)
+
+        # $b_i = b_{intra, i} \odot b_{inter, i}$
+        # $b_{intra, i} = Softmax(N \cdot Norm(\delta_i \odot \delta_i))$
+        norm_tvs_sqr = F.normalize(tvs_flat * tvs_flat, dim=1)
+        b_intra = F.softmax(n * norm_tvs_sqr, dim=1)
+
+        # $b_{inter, i} = \sum_{j = 1}^{n} Softmax(Norm(\delta_i \odot \delta_j))$
+        b_inter = torch.zeros_like(tvs_flat)
+        for i in range(n):
+            inter_prod = tvs_flat[i] * tvs_flat
+            inter_norm = F.normalize(inter_prod, dim=1)
+            b_inter[i] = F.softmax(inter_norm, dim=1).sum(dim=0)
+
+        b = b_intra * b_inter
+        k = int(tvs_flat.shape[1] * self.density)
+        # $m_i$: keep only the top-k entries of $b_i$ (a `density` fraction)
+        _, indices = torch.topk(b, k, dim=1)
+        m = torch.zeros_like(b)
+        m.scatter_(1, indices, 1)
+
+        # $\hat{b}_i = b_i \odot m_i$
+        b_hat = b * m
+
+        # normalize balancing scores per parameter across models (avoid 0/0 for fully masked entries)
+        denom = torch.sum(b_hat, dim=0, keepdim=True)
+        weights = b_hat / denom.clamp_min(torch.finfo(denom.dtype).eps)
+        final_delta = torch.sum(tvs_flat * weights, dim=0).view(tvs.shape[1:])
+        return base + self.weight * final_delta
diff --git a/tests/test_basic_merges.py b/tests/test_basic_merges.py
index ae54de43..e8ac7ffe 100644
--- a/tests/test_basic_merges.py
+++ b/tests/test_basic_merges.py
@@ -137,6 +137,26 @@ def test_model_stock_merge(self, model_a, model_b, model_c):
         config = self.two_model_config(
             model_a,
             model_b,
             merge_method="model_stock",
             base_model=model_c,
         )
         run_and_check_merge(config)
 
+    def test_della_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="della",
+            base_model=model_c,
+            params={"density": 0.7, "epsilon": 0.15, "lambda": 1.0},
+        )
+        run_and_check_merge(config)
+
+    def test_pcb_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="pcb",
+            base_model=model_c,
params={"density": 0.8, "weight": 1.0}, + ) + run_and_check_merge(config) + def test_model_stock_filterwise_merge(self, model_a, model_b, model_c): config = self.two_model_config( model_b,