From 655ff273e311c0306e90890637d0eaa9d4c8edac Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 14:28:13 +0800 Subject: [PATCH 01/11] Update registry.py Add CABSMerge --- mergekit/merge_methods/registry.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mergekit/merge_methods/registry.py b/mergekit/merge_methods/registry.py index 7b40f4a3..3a9f3b56 100644 --- a/mergekit/merge_methods/registry.py +++ b/mergekit/merge_methods/registry.py @@ -9,6 +9,7 @@ ConsensusMethod, GeneralizedTaskArithmeticMerge, ) +from mergekit.merge_methods.cabs import CABSMerge from mergekit.merge_methods.karcher import KarcherMerge from mergekit.merge_methods.linear import LinearMerge from mergekit.merge_methods.model_stock import ModelStockMerge @@ -25,6 +26,7 @@ ModelStockMerge(), ArceeFusionMerge(), KarcherMerge(), + CABSMerge(), # generalized task arithmetic methods GeneralizedTaskArithmeticMerge( consensus_method=None, From 4c5913979e47336b1b06bd54d14017e9283c6fcb Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 14:31:51 +0800 Subject: [PATCH 02/11] Create cabs.py --- mergekit/merge_methods/cabs.py | 282 +++++++++++++++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 mergekit/merge_methods/cabs.py diff --git a/mergekit/merge_methods/cabs.py b/mergekit/merge_methods/cabs.py new file mode 100644 index 00000000..543a90b9 --- /dev/null +++ b/mergekit/merge_methods/cabs.py @@ -0,0 +1,282 @@ +# mergekit/merge_methods/cabs.py +import logging +import torch +from typing import List, Dict, Tuple, Any, Optional +from pydantic import BaseModel, Field, validator +from typing_extensions import override, Literal + +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) + +# --- Helper function for n:m structural pruning --- +def prune_n_m_structural( + tensor: torch.Tensor, + n_val: int, + m_val: int +) -> Tuple[torch.Tensor, torch.Tensor]: + if not isinstance(tensor, torch.Tensor) or tensor.numel() == 0: + return tensor.clone(), torch.zeros_like(tensor, dtype=torch.bool) + original_shape = tensor.shape + device = tensor.device + flat_tensor_orig_values = tensor.flatten().clone() + num_elements = flat_tensor_orig_values.numel() + if m_val <= 0: + logging.error(f"Tensor shape {original_shape}: m_val ({m_val}) must be positive. No pruning.") + return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + if n_val < 0 or n_val > m_val: + logging.error(f"Tensor shape {original_shape}: n_val ({n_val}) invalid. 
No pruning.") + return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + if n_val == 0: + return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) + if n_val == m_val: + return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + padding = (m_val - (num_elements % m_val)) % m_val + if padding > 0: + flat_tensor_padded = torch.cat( + (flat_tensor_orig_values, torch.zeros(padding, device=device, dtype=tensor.dtype)) + ) + else: + flat_tensor_padded = flat_tensor_orig_values + reshaped_tensor = flat_tensor_padded.reshape(-1, m_val) + if reshaped_tensor.numel() == 0: + return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) + magnitudes = torch.abs(reshaped_tensor) + _, top_n_indices_in_blocks = torch.topk(magnitudes, k=n_val, dim=1) + nm_mask_blocks = torch.zeros_like(reshaped_tensor, dtype=torch.bool, device=device) + nm_mask_blocks.scatter_(1, top_n_indices_in_blocks, True) + nm_mask_flat_padded = nm_mask_blocks.flatten() + if padding > 0: + nm_mask_unpadded = nm_mask_flat_padded[:-padding] + else: + nm_mask_unpadded = nm_mask_flat_padded + final_mask_reshaped = nm_mask_unpadded.reshape(original_shape) + final_pruned_tensor = tensor * final_mask_reshaped + return final_pruned_tensor, final_mask_reshaped + +# --- Mergekit Method Definition --- +class CABSMerge(MergeMethod, BaseModel, frozen=True): + # These fields are part of the method's configuration, settable via YAML. + # Pydantic uses these defaults if not provided in YAML for the 'cabs' method block. + default_n_m_ratio: Optional[Tuple[int, int]] = Field( + default=None, + description="Optional global default [n, m] ratio for n:m pruning. E.g., [1, 4]." + ) + pruning_order: Optional[List[str]] = Field( + default=None, + description="Optional: List of model source names (from YAML 'sources') defining the CA processing order." + ) + + # These are more like fixed properties of the method, not typically changed by user YAML for 'cabs' + # but Pydantic treats them as fields that can be initialized. + # Mergekit's MergeMethod.create will pass YAML params, potentially overriding these if keys match. + # It's safer to have these as fixed return values in name(), pretty_name(), etc. if they are truly static. + # However, to allow Mergekit's create(**kwargs) to work seamlessly if it tries to pass them, + # we keep them as fields with defaults. + method_name_override: Optional[str] = Field(default=None, exclude=True) # For internal use if variants are registered + method_pretty_name_override: Optional[str] = Field(default=None, exclude=True) + method_reference_url_override: Optional[str] = Field(default=None, exclude=True) + + + @validator('default_n_m_ratio', pre=True, always=True) + def check_default_n_m_ratio(cls, v: Any) -> Optional[Tuple[int, int]]: + if v is not None: + if not (isinstance(v, (list, tuple)) and len(v) == 2 and + isinstance(v[0], int) and isinstance(v[1], int) and + 0 <= v[0] <= v[1] and v[1] > 0): + raise ValueError( + "default_n_m_ratio must be a tuple/list of two integers [n, m] " + "with 0 <= n <= m and m > 0, or null." 
+ ) + return tuple(v) + return None + + @override + def name(self) -> str: + return self.method_name_override or "cabs" + + @override + def pretty_name(self) -> Optional[str]: + return self.method_pretty_name_override or "Conflict-Aware and Balanced Sparsification" + + @override + def reference_url(self) -> Optional[str]: + return self.method_reference_url_override or "https://arxiv.org/abs/2503.01874" + + @override + def parameters(self) -> List[ConfigParameterDef]: + return [ + ConfigParameterDef( + name="default_n_m_ratio", type="list[int]", required=False, default_value=None, + description="Optional global default [n, m] ratio. Models can override this. Example: [1, 4]" + ), + ConfigParameterDef( + name="pruning_order", type="list[str]", required=False, default_value=None, + description="Optional: List of model source names (from YAML 'sources') defining the CA processing order." + ), + # These are not typically set by users for the primary "cabs" method, but allow for variants if needed. + # ConfigParameterDef(name="method_name_override", type="str", required=False, advanced=True), + # ConfigParameterDef(name="method_pretty_name_override", type="str", required=False, advanced=True), + # ConfigParameterDef(name="method_reference_url_override", type="str", required=False, advanced=True), + ] + + @override + def tensor_parameters(self) -> List[ConfigParameterDef]: + return [ + ConfigParameterDef( + name="weight", type="float", required=False, default_value=1.0, + description="Scaling coefficient (lambda) for this model's task vector." + ), + ConfigParameterDef( + name="n_m_ratio", type="list[int]", required=False, default_value=None, + description="Per-model [n, m] ratio for n:m pruning. Overrides global default_n_m_ratio. Example: [1, 2]" + ), + ] + + @override + def make_task( + self, + output_weight: WeightInfo, + tensors: MergeTensorInput, + base_model: Optional[ModelReference], + parameters: ImmutableMap[str, Any], # parameters from YAML for THIS method invocation + tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], + ) -> Task: + if base_model is None: + logging.error("CABS merge requires a 'base_model' to be specified in the YAML.") + raise ValueError("CABS merge requires a 'base_model'.") + + # 'self' is the instance from STATIC_MERGE_METHODS. + # 'parameters' contains the YAML overrides for default_n_m_ratio, pruning_order, etc. + # We should use 'parameters' to construct the CABSTask or configure the instance. + # Mergekit's MergeMethod.create typically handles creating a new instance with YAML params. + # So, self.default_n_m_ratio etc. on *this specific 'self'* instance will be the final ones. 
+ + return CABSTask( + method_config=self, # 'self' is the correctly configured CABSMerge instance + tensors_input=tensors, + base_model_ref=base_model, + current_weight_info=output_weight, + per_model_tensor_params=tensor_parameters, + ) + +class CABSTask(Task[torch.Tensor]): + method_config: CABSMerge + tensors_input: MergeTensorInput + base_model_ref: ModelReference + current_weight_info: WeightInfo + per_model_tensor_params: ImmutableMap[ModelReference, ImmutableMap[str, Any]] + + FALLBACK_N_M_RATIO: Tuple[int, int] = (1, 4) + + @override + def uses_accelerator(self) -> bool: + return True + + @override + def arguments(self) -> Dict[str, Task]: + return {"tensors_arg": self.tensors_input} + + def _resolve_nm_ratio_for_model(self, model_ref: ModelReference) -> Tuple[int, int]: + current_model_params = self.per_model_tensor_params.get(model_ref, ImmutableMap({})) + per_model_nm_ratio_raw = current_model_params.get("n_m_ratio") + + if per_model_nm_ratio_raw is not None: + if not (isinstance(per_model_nm_ratio_raw, (list, tuple)) and len(per_model_nm_ratio_raw) == 2 and + isinstance(per_model_nm_ratio_raw[0], int) and isinstance(per_model_nm_ratio_raw[1], int) and + 0 <= per_model_nm_ratio_raw[0] <= per_model_nm_ratio_raw[1] and per_model_nm_ratio_raw[1] > 0): + logging.warning(f"Invalid n_m_ratio {per_model_nm_ratio_raw} for model {model_ref.name} " + f"on tensor {self.current_weight_info.name}. " + f"Falling back.") + else: + return int(per_model_nm_ratio_raw[0]), int(per_model_nm_ratio_raw[1]) + + if self.method_config.default_n_m_ratio is not None: # Use from configured instance + return self.method_config.default_n_m_ratio + + logging.warning(f"No n_m_ratio specified for model {model_ref.name} and no global default " + f"for tensor {self.current_weight_info.name}. 
" + f"Using hardcoded fallback: {self.FALLBACK_N_M_RATIO}.") + return self.FALLBACK_N_M_RATIO + + @override + def execute( + self, + tensors_arg: Dict[ModelReference, torch.Tensor], + **_kwargs, + ) -> torch.Tensor: + if self.base_model_ref not in tensors_arg: + logging.error(f"Base model '{self.base_model_ref.name}' tensor not found for weight '{self.current_weight_info.name}'.") + device_str = self.current_weight_info.device_str() if hasattr(self.current_weight_info, 'device_str') else 'cpu' + dtype_val = self.current_weight_info.dtype if hasattr(self.current_weight_info, 'dtype') else torch.float32 + return torch.empty(0, device=torch.device(device_str), dtype=dtype_val) + + target_device = tensors_arg[self.base_model_ref].device + target_dtype = tensors_arg[self.base_model_ref].dtype + merged_tensor_accumulator = tensors_arg[self.base_model_ref].clone().to(device=target_device, dtype=target_dtype) + + ordered_model_refs_for_ca: List[ModelReference] = [] + model_ref_by_name_map: Dict[str, ModelReference] = { ref.name: ref for ref in tensors_arg.keys() } + + current_pruning_order_names = self.method_config.pruning_order # Get from configured instance + if current_pruning_order_names: + for name in current_pruning_order_names: + if name == self.base_model_ref.name: + continue + if name in model_ref_by_name_map: + ordered_model_refs_for_ca.append(model_ref_by_name_map[name]) + else: + logging.warning(f"Model '{name}' from pruning_order not found among available tensors " + f"for weight '{self.current_weight_info.name}', skipping this entry in order.") + else: + sorted_non_base_names = sorted([ref.name for ref in tensors_arg.keys() if ref != self.base_model_ref]) + for name in sorted_non_base_names: + if name in model_ref_by_name_map: + ordered_model_refs_for_ca.append(model_ref_by_name_map[name]) + + if not ordered_model_refs_for_ca: + logging.info(f"No non-base models to merge for weight '{self.current_weight_info.name}'. 
" + "Returning base tensor.") + return merged_tensor_accumulator + + cumulative_param_mask = torch.zeros_like(merged_tensor_accumulator, dtype=torch.bool, device=target_device) + + for model_ref_current in ordered_model_refs_for_ca: + if model_ref_current not in tensors_arg: + logging.warning(f"Tensor for model '{model_ref_current.name}' became unavailable during processing " + f"for weight '{self.current_weight_info.name}', skipping.") + continue + + fine_tuned_tensor_val = tensors_arg[model_ref_current].to(device=target_device, dtype=target_dtype) + base_tensor_for_diff = tensors_arg[self.base_model_ref].to(device=target_device, dtype=target_dtype) + + current_model_params_map = self.per_model_tensor_params.get(model_ref_current, ImmutableMap({})) + scaling_coefficient = float(current_model_params_map.get("weight", 1.0)) + n_val_current, m_val_current = self._resolve_nm_ratio_for_model(model_ref_current) + + task_vector_val = fine_tuned_tensor_val - base_tensor_for_diff + available_params_mask = ~cumulative_param_mask + candidate_task_vector = task_vector_val * available_params_mask.to(task_vector_val.dtype) + + pruned_task_vector, newly_retained_mask = prune_n_m_structural( + candidate_task_vector, + n_val_current, + m_val_current + ) + + merged_tensor_accumulator += scaling_coefficient * pruned_task_vector.to(merged_tensor_accumulator.dtype) + cumulative_param_mask = torch.logical_or( + cumulative_param_mask, + newly_retained_mask.to(device=cumulative_param_mask.device) + ) + + return merged_tensor_accumulator + + @override + def group_label(self) -> Optional[str]: + return self.current_weight_info.name From bd5dded542e9bc675d10dc4e23db883b454351c7 Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 14:32:35 +0800 Subject: [PATCH 03/11] Create cabs.yml --- examples/cabs.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 examples/cabs.yml diff --git a/examples/cabs.yml b/examples/cabs.yml new file mode 100644 index 00000000..e6812fd7 --- /dev/null +++ b/examples/cabs.yml @@ -0,0 +1,29 @@ +# cabs_test_config.yaml +base_model: mistral_base +merge_method: cabs # This will pick up the CABSMerge() instance from registry + +# Parameters for the 'cabs' method. These will be passed to Mergekit's +# MergeMethod.create mechanism, which effectively configures the CABSMerge instance for this run. +default_n_m_ratio: [1, 4] +pruning_order: + - zephyr_alpha + - zephyr_beta + +sources: + mistral_base: + path: /home/yangzz/models_test/Mistral-7b-v0.1/ + zephyr_alpha: + path: /home/yangzz/models_test/zephyr-7b-alpha/ + parameters: + weight: 0.6 + n_m_ratio: [1, 2] + zephyr_beta: + path: /home/yangzz/models_test/zephyr-7b-beta/ + parameters: + weight: 0.4 + # n_m_ratio for zephyr_beta will use 'default_n_m_ratio' [1,4] from above + +out_path: ./merged_cabs_zephyr_output +dtype: bfloat16 +low_cpu_memory: true +copy_tokenizer: true From 268065a2034aa8c0635fb7277866c7ddcaff8d97 Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 14:34:44 +0800 Subject: [PATCH 04/11] Update cabs.yml --- examples/cabs.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/cabs.yml b/examples/cabs.yml index e6812fd7..bda2989d 100644 --- a/examples/cabs.yml +++ b/examples/cabs.yml @@ -1,9 +1,6 @@ -# cabs_test_config.yaml -base_model: mistral_base -merge_method: cabs # This will pick up the CABSMerge() instance from registry -# Parameters for the 'cabs' method. 
These will be passed to Mergekit's -# MergeMethod.create mechanism, which effectively configures the CABSMerge instance for this run. +base_model: mistral_base +merge_method: cabs default_n_m_ratio: [1, 4] pruning_order: - zephyr_alpha From 7ba30cdd940eb8f74f578716168d9e452de913bd Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 14:35:11 +0800 Subject: [PATCH 05/11] Update cabs.yml --- examples/cabs.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/cabs.yml b/examples/cabs.yml index bda2989d..6d6bae25 100644 --- a/examples/cabs.yml +++ b/examples/cabs.yml @@ -13,12 +13,12 @@ sources: path: /home/yangzz/models_test/zephyr-7b-alpha/ parameters: weight: 0.6 - n_m_ratio: [1, 2] + n_m_ratio: [8, 32] zephyr_beta: path: /home/yangzz/models_test/zephyr-7b-beta/ parameters: - weight: 0.4 - # n_m_ratio for zephyr_beta will use 'default_n_m_ratio' [1,4] from above + weight: 0.4 + n_m_ratio: [8, 32] out_path: ./merged_cabs_zephyr_output dtype: bfloat16 From 0934a1629d6f57180eacd5061ffcd383bd42f87f Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 14:35:22 +0800 Subject: [PATCH 06/11] Update cabs.yml --- examples/cabs.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/cabs.yml b/examples/cabs.yml index 6d6bae25..dc73588b 100644 --- a/examples/cabs.yml +++ b/examples/cabs.yml @@ -1,4 +1,3 @@ - base_model: mistral_base merge_method: cabs default_n_m_ratio: [1, 4] From a65b5ac7eda7f4d326aba8e602178bbc50b94b91 Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 15:34:59 +0800 Subject: [PATCH 07/11] Update cabs.py --- mergekit/merge_methods/cabs.py | 382 +++++++++++++++++---------------- 1 file changed, 200 insertions(+), 182 deletions(-) diff --git a/mergekit/merge_methods/cabs.py b/mergekit/merge_methods/cabs.py index 543a90b9..f5f29039 100644 --- a/mergekit/merge_methods/cabs.py +++ b/mergekit/merge_methods/cabs.py @@ -2,7 +2,7 @@ import logging import torch from typing import List, Dict, Tuple, Any, Optional -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, validator, root_validator # Keep BaseModel for CABSMerge for now if registry expects instance from typing_extensions import override, Literal from mergekit.architecture import WeightInfo @@ -14,7 +14,7 @@ MergeTensorInput, ) -# --- Helper function for n:m structural pruning --- +# --- Helper function for n:m structural pruning (remains the same) --- def prune_n_m_structural( tensor: torch.Tensor, n_val: int, @@ -22,120 +22,89 @@ def prune_n_m_structural( ) -> Tuple[torch.Tensor, torch.Tensor]: if not isinstance(tensor, torch.Tensor) or tensor.numel() == 0: return tensor.clone(), torch.zeros_like(tensor, dtype=torch.bool) - original_shape = tensor.shape - device = tensor.device - flat_tensor_orig_values = tensor.flatten().clone() - num_elements = flat_tensor_orig_values.numel() - if m_val <= 0: - logging.error(f"Tensor shape {original_shape}: m_val ({m_val}) must be positive. No pruning.") - return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) - if n_val < 0 or n_val > m_val: - logging.error(f"Tensor shape {original_shape}: n_val ({n_val}) invalid. 
No pruning.") - return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) - if n_val == 0: - return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) - if n_val == m_val: - return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) - padding = (m_val - (num_elements % m_val)) % m_val - if padding > 0: - flat_tensor_padded = torch.cat( - (flat_tensor_orig_values, torch.zeros(padding, device=device, dtype=tensor.dtype)) - ) - else: - flat_tensor_padded = flat_tensor_orig_values + original_shape = tensor.shape; device = tensor.device + flat_tensor_orig_values = tensor.flatten().clone(); num_elements = flat_tensor_orig_values.numel() + if m_val <= 0: logging.error(f"Tensor {original_shape}: m_val ({m_val}) must be positive."); return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + if n_val < 0 or n_val > m_val: logging.error(f"Tensor {original_shape}: n_val ({n_val}) invalid."); return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + if n_val == 0: return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) + if n_val == m_val: return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + padding = (m_val-(num_elements % m_val))%m_val + if padding > 0: flat_tensor_padded = torch.cat((flat_tensor_orig_values, torch.zeros(padding,device=device,dtype=tensor.dtype))) + else: flat_tensor_padded = flat_tensor_orig_values reshaped_tensor = flat_tensor_padded.reshape(-1, m_val) - if reshaped_tensor.numel() == 0: - return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) - magnitudes = torch.abs(reshaped_tensor) - _, top_n_indices_in_blocks = torch.topk(magnitudes, k=n_val, dim=1) - nm_mask_blocks = torch.zeros_like(reshaped_tensor, dtype=torch.bool, device=device) - nm_mask_blocks.scatter_(1, top_n_indices_in_blocks, True) + if reshaped_tensor.numel()==0: return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) + magnitudes = torch.abs(reshaped_tensor); _,top_n_indices_in_blocks = torch.topk(magnitudes,k=n_val,dim=1) + nm_mask_blocks = torch.zeros_like(reshaped_tensor,dtype=torch.bool,device=device); nm_mask_blocks.scatter_(1,top_n_indices_in_blocks,True) nm_mask_flat_padded = nm_mask_blocks.flatten() - if padding > 0: - nm_mask_unpadded = nm_mask_flat_padded[:-padding] - else: - nm_mask_unpadded = nm_mask_flat_padded - final_mask_reshaped = nm_mask_unpadded.reshape(original_shape) - final_pruned_tensor = tensor * final_mask_reshaped + if padding > 0: nm_mask_unpadded = nm_mask_flat_padded[:-padding] + else: nm_mask_unpadded = nm_mask_flat_padded + final_mask_reshaped = nm_mask_unpadded.reshape(original_shape); final_pruned_tensor = tensor * final_mask_reshaped return final_pruned_tensor, final_mask_reshaped # --- Mergekit Method Definition --- -class CABSMerge(MergeMethod, BaseModel, frozen=True): - # These fields are part of the method's configuration, settable via YAML. - # Pydantic uses these defaults if not provided in YAML for the 'cabs' method block. - default_n_m_ratio: Optional[Tuple[int, int]] = Field( - default=None, - description="Optional global default [n, m] ratio for n:m pruning. E.g., [1, 4]." - ) - pruning_order: Optional[List[str]] = Field( - default=None, - description="Optional: List of model source names (from YAML 'sources') defining the CA processing order." - ) - - # These are more like fixed properties of the method, not typically changed by user YAML for 'cabs' - # but Pydantic treats them as fields that can be initialized. 
- # Mergekit's MergeMethod.create will pass YAML params, potentially overriding these if keys match. - # It's safer to have these as fixed return values in name(), pretty_name(), etc. if they are truly static. - # However, to allow Mergekit's create(**kwargs) to work seamlessly if it tries to pass them, - # we keep them as fields with defaults. - method_name_override: Optional[str] = Field(default=None, exclude=True) # For internal use if variants are registered - method_pretty_name_override: Optional[str] = Field(default=None, exclude=True) - method_reference_url_override: Optional[str] = Field(default=None, exclude=True) +# We can keep CABSMerge as a Pydantic BaseModel if Mergekit's registry.py instantiates it directly +# and then MergeMethod.create re-instantiates with YAML parameters. +# Or, if registry.py stores the *class* and MergeMethod.create instantiates it once with YAML params, +# then it also works. Let's assume the latter for now for parameter passing simplicity. +class CABSMerge(MergeMethod, BaseModel, frozen=True): + # These fields capture parameters from YAML that are sibling to 'merge_method: cabs' + # They are used if Mergekit passes them directly to CABSMerge constructor. + # If parameters are *only* from the nested 'parameters:' block, these can be removed, + # and CABSMerge becomes a simpler class just holding name/pretty_name. + # For consistency with how other methods might receive their top-level params via kwargs to __init__ + # by MergeMethod.create, we define them here. + default_n_val: Optional[int] = Field(default=None) + default_m_val: Optional[int] = Field(default=None) + pruning_order: Optional[List[str]] = Field(default=None) + method_name_arg: str = Field("cabs", alias="method_name", exclude=True) + method_pretty_name_arg: Optional[str] = Field("Conflict-Aware N:M Sparsification", alias="method_pretty_name", exclude=True) + method_reference_url_arg: Optional[str] = Field("https://arxiv.org/abs/2503.01874", alias="method_reference_url", exclude=True) - @validator('default_n_m_ratio', pre=True, always=True) - def check_default_n_m_ratio(cls, v: Any) -> Optional[Tuple[int, int]]: - if v is not None: - if not (isinstance(v, (list, tuple)) and len(v) == 2 and - isinstance(v[0], int) and isinstance(v[1], int) and - 0 <= v[0] <= v[1] and v[1] > 0): - raise ValueError( - "default_n_m_ratio must be a tuple/list of two integers [n, m] " - "with 0 <= n <= m and m > 0, or null." - ) - return tuple(v) - return None + @root_validator(pre=False, skip_on_failure=True) + def check_default_n_m_consistency(cls, values: Dict[str, Any]) -> Dict[str, Any]: + n = values.get('default_n_val') # Get from the instance's fields + m = values.get('default_m_val') + if n is not None and m is None: + raise ValueError("If 'default_n_val' is provided, 'default_m_val' must also be provided.") + if m is not None and n is None: + raise ValueError("If 'default_m_val' is provided, 'default_n_val' must also be provided.") + if n is not None and m is not None: + if not (isinstance(n, int) and n >= 0 and isinstance(m, int) and m > 0 and n <= m): + raise ValueError(f"Invalid default n/m values: n={n}, m={m}. 
Ensure 0 <= n <= m and m > 0.") + return values @override def name(self) -> str: - return self.method_name_override or "cabs" + return self.method_name_arg @override def pretty_name(self) -> Optional[str]: - return self.method_pretty_name_override or "Conflict-Aware and Balanced Sparsification" + return self.method_pretty_name_arg @override def reference_url(self) -> Optional[str]: - return self.method_reference_url_override or "https://arxiv.org/abs/2503.01874" + return self.method_reference_url_arg @override def parameters(self) -> List[ConfigParameterDef]: + # These declare what keys are expected in the YAML block for this method's global config. + # If these keys are siblings to 'merge_method: cabs', Mergekit passes them to __init__. + # If these keys are under a nested 'parameters:' block, Mergekit passes that block as + # the 'parameters' argument to make_task. + # Given KarcherMerge example, they are for the nested 'parameters:' block. return [ - ConfigParameterDef( - name="default_n_m_ratio", type="list[int]", required=False, default_value=None, - description="Optional global default [n, m] ratio. Models can override this. Example: [1, 4]" - ), - ConfigParameterDef( - name="pruning_order", type="list[str]", required=False, default_value=None, - description="Optional: List of model source names (from YAML 'sources') defining the CA processing order." - ), - # These are not typically set by users for the primary "cabs" method, but allow for variants if needed. - # ConfigParameterDef(name="method_name_override", type="str", required=False, advanced=True), - # ConfigParameterDef(name="method_pretty_name_override", type="str", required=False, advanced=True), - # ConfigParameterDef(name="method_reference_url_override", type="str", required=False, advanced=True), + ConfigParameterDef(name="default_n_val", type="int", required=False, default_value=None), + ConfigParameterDef(name="default_m_val", type="int", required=False, default_value=None), + ConfigParameterDef(name="pruning_order", type="list[str]", required=False, default_value=None), ] @override def tensor_parameters(self) -> List[ConfigParameterDef]: return [ - ConfigParameterDef( - name="weight", type="float", required=False, default_value=1.0, - description="Scaling coefficient (lambda) for this model's task vector." - ), - ConfigParameterDef( - name="n_m_ratio", type="list[int]", required=False, default_value=None, - description="Per-model [n, m] ratio for n:m pruning. Overrides global default_n_m_ratio. Example: [1, 2]" - ), + ConfigParameterDef(name="weight", type="float", required=False, default_value=1.0), + ConfigParameterDef(name="n_val", type="int", required=False, default_value=None), + ConfigParameterDef(name="m_val", type="int", required=False, default_value=None), ] @override @@ -143,138 +112,187 @@ def make_task( self, output_weight: WeightInfo, tensors: MergeTensorInput, - base_model: Optional[ModelReference], - parameters: ImmutableMap[str, Any], # parameters from YAML for THIS method invocation + base_model: Optional[ModelReference], + parameters: ImmutableMap[str, Any], # This map IS from the nested 'parameters:' block in YAML tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], ) -> Task: if base_model is None: - logging.error("CABS merge requires a 'base_model' to be specified in the YAML.") raise ValueError("CABS merge requires a 'base_model'.") - # 'self' is the instance from STATIC_MERGE_METHODS. - # 'parameters' contains the YAML overrides for default_n_m_ratio, pruning_order, etc. 
- # We should use 'parameters' to construct the CABSTask or configure the instance. - # Mergekit's MergeMethod.create typically handles creating a new instance with YAML params. - # So, self.default_n_m_ratio etc. on *this specific 'self'* instance will be the final ones. + # Extract global CABS parameters SAFELY from the 'parameters' ImmutableMap + default_n_val_global: Optional[int] = None + if "default_n_val" in parameters: + val = parameters["default_n_val"] + if isinstance(val, int): default_n_val_global = val + elif val is not None: logging.warning(f"Expected int for default_n_val, got {type(val)}") + + default_m_val_global: Optional[int] = None + if "default_m_val" in parameters: + val = parameters["default_m_val"] + if isinstance(val, int): default_m_val_global = val + elif val is not None: logging.warning(f"Expected int for default_m_val, got {type(val)}") + + pruning_order_global: Optional[List[str]] = None + if "pruning_order" in parameters: + val = parameters["pruning_order"] + if isinstance(val, list) and all(isinstance(s, str) for s in val): + pruning_order_global = val + elif val is not None: + logging.warning(f"Expected list of strings for pruning_order, got {type(val)}") + # Validation for n and m consistency if both are provided globally + if default_n_val_global is not None and default_m_val_global is None: + raise ValueError("If 'default_n_val' is provided in global parameters, 'default_m_val' must also be provided.") + if default_m_val_global is not None and default_n_val_global is None: + raise ValueError("If 'default_m_val' is provided in global parameters, 'default_n_val' must also be provided.") + if default_n_val_global is not None and default_m_val_global is not None: # Both are provided + if not (default_n_val_global >= 0 and default_m_val_global > 0 and default_n_val_global <= default_m_val_global): + raise ValueError(f"Invalid global default n/m values: n={default_n_val_global}, m={default_m_val_global}. 
" + "Ensure 0 <= n <= m and m > 0.") + return CABSTask( - method_config=self, # 'self' is the correctly configured CABSMerge instance + global_default_n_val=default_n_val_global, + global_default_m_val=default_m_val_global, + global_pruning_order=pruning_order_global, tensors_input=tensors, - base_model_ref=base_model, + base_model_ref=base_model, current_weight_info=output_weight, per_model_tensor_params=tensor_parameters, ) class CABSTask(Task[torch.Tensor]): - method_config: CABSMerge + global_default_n_val: Optional[int] + global_default_m_val: Optional[int] + global_pruning_order: Optional[List[str]] + tensors_input: MergeTensorInput - base_model_ref: ModelReference + base_model_ref: ModelReference current_weight_info: WeightInfo per_model_tensor_params: ImmutableMap[ModelReference, ImmutableMap[str, Any]] - FALLBACK_N_M_RATIO: Tuple[int, int] = (1, 4) + FALLBACK_N_VAL: int = 1 + FALLBACK_M_VAL: int = 4 @override - def uses_accelerator(self) -> bool: - return True - + def uses_accelerator(self) -> bool: return True @override - def arguments(self) -> Dict[str, Task]: - return {"tensors_arg": self.tensors_input} + def arguments(self) -> Dict[str, Task]: return {"tensors_arg": self.tensors_input} - def _resolve_nm_ratio_for_model(self, model_ref: ModelReference) -> Tuple[int, int]: - current_model_params = self.per_model_tensor_params.get(model_ref, ImmutableMap({})) - per_model_nm_ratio_raw = current_model_params.get("n_m_ratio") +# In CABSTask class, within cabs.py - if per_model_nm_ratio_raw is not None: - if not (isinstance(per_model_nm_ratio_raw, (list, tuple)) and len(per_model_nm_ratio_raw) == 2 and - isinstance(per_model_nm_ratio_raw[0], int) and isinstance(per_model_nm_ratio_raw[1], int) and - 0 <= per_model_nm_ratio_raw[0] <= per_model_nm_ratio_raw[1] and per_model_nm_ratio_raw[1] > 0): - logging.warning(f"Invalid n_m_ratio {per_model_nm_ratio_raw} for model {model_ref.name} " - f"on tensor {self.current_weight_info.name}. " - f"Falling back.") - else: - return int(per_model_nm_ratio_raw[0]), int(per_model_nm_ratio_raw[1]) + def _resolve_n_and_m_for_model(self, model_ref: ModelReference) -> Tuple[int, int]: + per_model_n_raw: Any = None # Use Any to catch the actual type passed + per_model_m_raw: Any = None + + model_identifier_str = str(model_ref.model) + logging.debug(f"Resolving n and m for model: {model_identifier_str}") + + if model_ref in self.per_model_tensor_params: + current_model_inner_params_map = self.per_model_tensor_params[model_ref] + logging.debug(f" Per-model params for {model_identifier_str}: {dict(current_model_inner_params_map)}") + + if "n_val" in current_model_inner_params_map: + per_model_n_raw = current_model_inner_params_map["n_val"] + logging.debug(f" Raw per-model n_val: {per_model_n_raw} (type: {type(per_model_n_raw)})") + if "m_val" in current_model_inner_params_map: + per_model_m_raw = current_model_inner_params_map["m_val"] + logging.debug(f" Raw per-model m_val: {per_model_m_raw} (type: {type(per_model_m_raw)})") + + # Attempt to convert to int if they are floats representing whole numbers + def try_convert_to_int(val: Any, name: str) -> Optional[int]: + if isinstance(val, int): + return val + if isinstance(val, float): + if val.is_integer(): + return int(val) + else: + logging.warning(f" Cannot convert per-model {name} '{val}' to int as it's a non-whole float.") + return None + if val is not None: # Log if it's some other unexpected type + logging.warning(f" Unexpected type for per-model {name}: {type(val)}. 
Expected int or float representing int.") + return None - if self.method_config.default_n_m_ratio is not None: # Use from configured instance - return self.method_config.default_n_m_ratio + n_candidate: Optional[int] = None + m_candidate: Optional[int] = None - logging.warning(f"No n_m_ratio specified for model {model_ref.name} and no global default " - f"for tensor {self.current_weight_info.name}. " - f"Using hardcoded fallback: {self.FALLBACK_N_M_RATIO}.") - return self.FALLBACK_N_M_RATIO + if per_model_n_raw is not None: + n_candidate = try_convert_to_int(per_model_n_raw, "n_val") + if per_model_m_raw is not None: + m_candidate = try_convert_to_int(per_model_m_raw, "m_val") + + # Check if per-model n and m are consistently provided and valid AFTER conversion attempt + if n_candidate is not None and m_candidate is not None: + if n_candidate >= 0 and m_candidate > 0 and n_candidate <= m_candidate: + logging.debug(f" Using per-model n_val={n_candidate}, m_val={m_candidate} for {model_identifier_str}") + return n_candidate, m_candidate + else: + logging.warning(f" Invalid per-model n_val/m_val after conversion: n={n_candidate}, m={m_candidate} " + f"for model {model_identifier_str}. Will try global defaults.") + # If only one was provided or conversion failed for one, it's an incomplete/invalid pair + elif n_candidate is not None or m_candidate is not None: + logging.warning(f" Incomplete or invalid per-model n_val/m_val after conversion: n={n_candidate}, m={m_candidate} " + f"for model {model_identifier_str}. Both valid integers are required if one is set. Will try global defaults.") + + # Try global default parameters if per-model not valid or not fully set + if self.global_default_n_val is not None and self.global_default_m_val is not None: + # Global defaults already validated by CABSMerge.make_task + logging.debug(f" Using global default_n_val={self.global_default_n_val}, " + f"default_m_val={self.global_default_m_val} for {model_identifier_str}") + return self.global_default_n_val, self.global_default_m_val + + logging.warning(f" No valid per-model or global default n/m values for model {model_identifier_str} " + f"on tensor {self.current_weight_info.name}. 
" + f"Using hardcoded fallback: n={self.FALLBACK_N_VAL}, m={self.FALLBACK_M_VAL}.") + return self.FALLBACK_N_VAL, self.FALLBACK_M_VAL @override - def execute( - self, - tensors_arg: Dict[ModelReference, torch.Tensor], - **_kwargs, - ) -> torch.Tensor: + def execute(self, tensors_arg: Dict[ModelReference, torch.Tensor], **_kwargs) -> torch.Tensor: + base_model_identifier_str = str(self.base_model_ref.model) if self.base_model_ref not in tensors_arg: - logging.error(f"Base model '{self.base_model_ref.name}' tensor not found for weight '{self.current_weight_info.name}'.") - device_str = self.current_weight_info.device_str() if hasattr(self.current_weight_info, 'device_str') else 'cpu' - dtype_val = self.current_weight_info.dtype if hasattr(self.current_weight_info, 'dtype') else torch.float32 + logging.error(f"Base model '{base_model_identifier_str}' not found for '{self.current_weight_info.name}'.") + device_str = self.current_weight_info.device_str() if hasattr(self.current_weight_info,'device_str') and callable(self.current_weight_info.device_str) else 'cpu' + dtype_val = self.current_weight_info.dtype if hasattr(self.current_weight_info,'dtype') else torch.float32 return torch.empty(0, device=torch.device(device_str), dtype=dtype_val) - target_device = tensors_arg[self.base_model_ref].device - target_dtype = tensors_arg[self.base_model_ref].dtype - merged_tensor_accumulator = tensors_arg[self.base_model_ref].clone().to(device=target_device, dtype=target_dtype) - + target_device = tensors_arg[self.base_model_ref].device; target_dtype = tensors_arg[self.base_model_ref].dtype + merged_tensor_accumulator = tensors_arg[self.base_model_ref].clone().to(device=target_device,dtype=target_dtype) ordered_model_refs_for_ca: List[ModelReference] = [] - model_ref_by_name_map: Dict[str, ModelReference] = { ref.name: ref for ref in tensors_arg.keys() } + model_ref_by_string_id: Dict[str, ModelReference] = {str(ref.model): ref for ref in tensors_arg.keys()} - current_pruning_order_names = self.method_config.pruning_order # Get from configured instance - if current_pruning_order_names: - for name in current_pruning_order_names: - if name == self.base_model_ref.name: - continue - if name in model_ref_by_name_map: - ordered_model_refs_for_ca.append(model_ref_by_name_map[name]) - else: - logging.warning(f"Model '{name}' from pruning_order not found among available tensors " - f"for weight '{self.current_weight_info.name}', skipping this entry in order.") + current_pruning_order_strings = self.global_pruning_order + if current_pruning_order_strings: + for id_str_in_order in current_pruning_order_strings: + if id_str_in_order == base_model_identifier_str: continue + if id_str_in_order in model_ref_by_string_id: ordered_model_refs_for_ca.append(model_ref_by_string_id[id_str_in_order]) + else: logging.warning(f"Model ID '{id_str_in_order}' from order not found for '{self.current_weight_info.name}'.") else: - sorted_non_base_names = sorted([ref.name for ref in tensors_arg.keys() if ref != self.base_model_ref]) - for name in sorted_non_base_names: - if name in model_ref_by_name_map: - ordered_model_refs_for_ca.append(model_ref_by_name_map[name]) + sorted_non_base_string_ids = sorted([str(ref.model) for ref in tensors_arg.keys() if str(ref.model) != base_model_identifier_str]) + for id_str in sorted_non_base_string_ids: + if id_str in model_ref_by_string_id: ordered_model_refs_for_ca.append(model_ref_by_string_id[id_str]) if not ordered_model_refs_for_ca: - logging.info(f"No non-base models to merge for weight 
'{self.current_weight_info.name}'. " - "Returning base tensor.") - return merged_tensor_accumulator - - cumulative_param_mask = torch.zeros_like(merged_tensor_accumulator, dtype=torch.bool, device=target_device) + logging.info(f"No non-base models for '{self.current_weight_info.name}'. Returning base."); return merged_tensor_accumulator + cumulative_param_mask = torch.zeros_like(merged_tensor_accumulator,dtype=torch.bool,device=target_device) for model_ref_current in ordered_model_refs_for_ca: - if model_ref_current not in tensors_arg: - logging.warning(f"Tensor for model '{model_ref_current.name}' became unavailable during processing " - f"for weight '{self.current_weight_info.name}', skipping.") - continue - - fine_tuned_tensor_val = tensors_arg[model_ref_current].to(device=target_device, dtype=target_dtype) - base_tensor_for_diff = tensors_arg[self.base_model_ref].to(device=target_device, dtype=target_dtype) - - current_model_params_map = self.per_model_tensor_params.get(model_ref_current, ImmutableMap({})) - scaling_coefficient = float(current_model_params_map.get("weight", 1.0)) - n_val_current, m_val_current = self._resolve_nm_ratio_for_model(model_ref_current) + if model_ref_current not in tensors_arg: logging.warning(f"Tensor for '{str(model_ref_current.model)}' unavailable for '{self.current_weight_info.name}'."); continue + fine_tuned_tensor_val = tensors_arg[model_ref_current].to(device=target_device,dtype=target_dtype) + base_tensor_for_diff = tensors_arg[self.base_model_ref].to(device=target_device,dtype=target_dtype) + scaling_coefficient = 1.0 + if model_ref_current in self.per_model_tensor_params: + inner_params = self.per_model_tensor_params[model_ref_current] + if "weight" in inner_params: + val = inner_params["weight"] + if isinstance(val, (int,float)): scaling_coefficient = float(val) + else: logging.warning(f"Expected float for per-model weight, got {type(val)}") + n_val_current,m_val_current = self._resolve_n_and_m_for_model(model_ref_current) task_vector_val = fine_tuned_tensor_val - base_tensor_for_diff available_params_mask = ~cumulative_param_mask candidate_task_vector = task_vector_val * available_params_mask.to(task_vector_val.dtype) - - pruned_task_vector, newly_retained_mask = prune_n_m_structural( - candidate_task_vector, - n_val_current, - m_val_current - ) - + pruned_task_vector,newly_retained_mask = prune_n_m_structural(candidate_task_vector,n_val_current,m_val_current) merged_tensor_accumulator += scaling_coefficient * pruned_task_vector.to(merged_tensor_accumulator.dtype) - cumulative_param_mask = torch.logical_or( - cumulative_param_mask, - newly_retained_mask.to(device=cumulative_param_mask.device) - ) - + cumulative_param_mask = torch.logical_or(cumulative_param_mask,newly_retained_mask.to(device=cumulative_param_mask.device)) return merged_tensor_accumulator @override From 8fb2f857a13c2dcda746681487c4861fd8bd199c Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 15:35:37 +0800 Subject: [PATCH 08/11] Update cabs.yml --- examples/cabs.yml | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/examples/cabs.yml b/examples/cabs.yml index dc73588b..89acea42 100644 --- a/examples/cabs.yml +++ b/examples/cabs.yml @@ -1,25 +1,26 @@ -base_model: mistral_base -merge_method: cabs -default_n_m_ratio: [1, 4] -pruning_order: - - zephyr_alpha - - zephyr_beta +# cabs_nm_split_test.yaml -sources: - mistral_base: - path: /home/yangzz/models_test/Mistral-7b-v0.1/ - zephyr_alpha: - path: 
/home/yangzz/models_test/zephyr-7b-alpha/ +models: + - model: ./models_test/Mistral-7b-v0.1/ # Identifier is the path string + - model: ./models_test/zephyr-7b-alpha/ # Identifier is the path string parameters: - weight: 0.6 - n_m_ratio: [8, 32] - zephyr_beta: - path: /home/yangzz/models_test/zephyr-7b-beta/ - parameters: - weight: 0.4 - n_m_ratio: [8, 32] + weight: 0.6 + n_val: 8 # Per-model n + m_val: 32 # Per-model m + - model: ./models_test/zephyr-7b-beta/ # Identifier is the path string + parameters: + weight: 0.4 + n_val: 8 # Per-model n + m_val: 32 # Per-model m + # n_val and m_val not set for zephyr_beta, will use global defaults + +merge_method: cabs +# Global method parameters, sibling to merge_method +default_n_val: 1 # Global default n +default_m_val: 4 # Global default m +pruning_order: + - ./models_test/zephyr-7b-alpha/ + - ./models_test/zephyr-7b-beta/ -out_path: ./merged_cabs_zephyr_output -dtype: bfloat16 -low_cpu_memory: true -copy_tokenizer: true +base_model: ./models_test/Mistral-7b-v0.1/ +dtype: bfloat16 From 6b05fb714db121675d1a9ac910ee2bba54c58bbb Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 15:45:38 +0800 Subject: [PATCH 09/11] Update cabs.yml --- examples/cabs.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/cabs.yml b/examples/cabs.yml index 89acea42..bd1a013f 100644 --- a/examples/cabs.yml +++ b/examples/cabs.yml @@ -1,5 +1,3 @@ -# cabs_nm_split_test.yaml - models: - model: ./models_test/Mistral-7b-v0.1/ # Identifier is the path string - model: ./models_test/zephyr-7b-alpha/ # Identifier is the path string @@ -16,8 +14,8 @@ models: merge_method: cabs # Global method parameters, sibling to merge_method -default_n_val: 1 # Global default n -default_m_val: 4 # Global default m +default_n_val: 8 # Global default n +default_m_val: 32 # Global default m pruning_order: - ./models_test/zephyr-7b-alpha/ - ./models_test/zephyr-7b-beta/ From 5068a33c1e551d2c370a70e9f36685cea938105e Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Fri, 9 May 2025 15:46:36 +0800 Subject: [PATCH 10/11] Update cabs.py --- mergekit/merge_methods/cabs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mergekit/merge_methods/cabs.py b/mergekit/merge_methods/cabs.py index f5f29039..317e8a14 100644 --- a/mergekit/merge_methods/cabs.py +++ b/mergekit/merge_methods/cabs.py @@ -1,4 +1,3 @@ -# mergekit/merge_methods/cabs.py import logging import torch from typing import List, Dict, Tuple, Any, Optional From 915219ac96978109ab5c58f956415b69effa145c Mon Sep 17 00:00:00 2001 From: zongzhenyang Date: Sun, 11 May 2025 13:19:37 +0800 Subject: [PATCH 11/11] isort and black hooks2 --- examples/cabs.yml | 8 +- mergekit/merge_methods/cabs.py | 395 ++++++++++++++++++++--------- mergekit/merge_methods/registry.py | 2 +- 3 files changed, 286 insertions(+), 119 deletions(-) diff --git a/examples/cabs.yml b/examples/cabs.yml index bd1a013f..aa35ddc9 100644 --- a/examples/cabs.yml +++ b/examples/cabs.yml @@ -1,12 +1,12 @@ models: - model: ./models_test/Mistral-7b-v0.1/ # Identifier is the path string - model: ./models_test/zephyr-7b-alpha/ # Identifier is the path string - parameters: + parameters: weight: 0.6 n_val: 8 # Per-model n m_val: 32 # Per-model m - model: ./models_test/zephyr-7b-beta/ # Identifier is the path string - parameters: + parameters: weight: 0.4 n_val: 8 # Per-model n m_val: 32 # Per-model m @@ -16,9 +16,9 @@ merge_method: cabs # Global method parameters, sibling to merge_method default_n_val: 8 # Global default n default_m_val: 32 # Global 
default m -pruning_order: +pruning_order: - ./models_test/zephyr-7b-alpha/ - ./models_test/zephyr-7b-beta/ -base_model: ./models_test/Mistral-7b-v0.1/ +base_model: ./models_test/Mistral-7b-v0.1/ dtype: bfloat16 diff --git a/mergekit/merge_methods/cabs.py b/mergekit/merge_methods/cabs.py index 317e8a14..84214f21 100644 --- a/mergekit/merge_methods/cabs.py +++ b/mergekit/merge_methods/cabs.py @@ -1,8 +1,14 @@ import logging +from typing import Any, Dict, List, Optional, Tuple + import torch -from typing import List, Dict, Tuple, Any, Optional -from pydantic import BaseModel, Field, validator, root_validator # Keep BaseModel for CABSMerge for now if registry expects instance -from typing_extensions import override, Literal +from pydantic import ( # Keep BaseModel for CABSMerge for now if registry expects instance + BaseModel, + Field, + root_validator, + validator, +) +from typing_extensions import Literal, override from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference @@ -13,39 +19,60 @@ MergeTensorInput, ) + # --- Helper function for n:m structural pruning (remains the same) --- def prune_n_m_structural( - tensor: torch.Tensor, - n_val: int, - m_val: int + tensor: torch.Tensor, n_val: int, m_val: int ) -> Tuple[torch.Tensor, torch.Tensor]: if not isinstance(tensor, torch.Tensor) or tensor.numel() == 0: return tensor.clone(), torch.zeros_like(tensor, dtype=torch.bool) - original_shape = tensor.shape; device = tensor.device - flat_tensor_orig_values = tensor.flatten().clone(); num_elements = flat_tensor_orig_values.numel() - if m_val <= 0: logging.error(f"Tensor {original_shape}: m_val ({m_val}) must be positive."); return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) - if n_val < 0 or n_val > m_val: logging.error(f"Tensor {original_shape}: n_val ({n_val}) invalid."); return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) - if n_val == 0: return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) - if n_val == m_val: return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) - padding = (m_val-(num_elements % m_val))%m_val - if padding > 0: flat_tensor_padded = torch.cat((flat_tensor_orig_values, torch.zeros(padding,device=device,dtype=tensor.dtype))) - else: flat_tensor_padded = flat_tensor_orig_values + original_shape = tensor.shape + device = tensor.device + flat_tensor_orig_values = tensor.flatten().clone() + num_elements = flat_tensor_orig_values.numel() + if m_val <= 0: + logging.error(f"Tensor {original_shape}: m_val ({m_val}) must be positive.") + return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + if n_val < 0 or n_val > m_val: + logging.error(f"Tensor {original_shape}: n_val ({n_val}) invalid.") + return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + if n_val == 0: + return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) + if n_val == m_val: + return tensor.clone(), torch.ones_like(tensor, dtype=torch.bool) + padding = (m_val - (num_elements % m_val)) % m_val + if padding > 0: + flat_tensor_padded = torch.cat( + ( + flat_tensor_orig_values, + torch.zeros(padding, device=device, dtype=tensor.dtype), + ) + ) + else: + flat_tensor_padded = flat_tensor_orig_values reshaped_tensor = flat_tensor_padded.reshape(-1, m_val) - if reshaped_tensor.numel()==0: return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) - magnitudes = torch.abs(reshaped_tensor); _,top_n_indices_in_blocks = torch.topk(magnitudes,k=n_val,dim=1) - 
nm_mask_blocks = torch.zeros_like(reshaped_tensor,dtype=torch.bool,device=device); nm_mask_blocks.scatter_(1,top_n_indices_in_blocks,True) + if reshaped_tensor.numel() == 0: + return torch.zeros_like(tensor), torch.zeros_like(tensor, dtype=torch.bool) + magnitudes = torch.abs(reshaped_tensor) + _, top_n_indices_in_blocks = torch.topk(magnitudes, k=n_val, dim=1) + nm_mask_blocks = torch.zeros_like(reshaped_tensor, dtype=torch.bool, device=device) + nm_mask_blocks.scatter_(1, top_n_indices_in_blocks, True) nm_mask_flat_padded = nm_mask_blocks.flatten() - if padding > 0: nm_mask_unpadded = nm_mask_flat_padded[:-padding] - else: nm_mask_unpadded = nm_mask_flat_padded - final_mask_reshaped = nm_mask_unpadded.reshape(original_shape); final_pruned_tensor = tensor * final_mask_reshaped + if padding > 0: + nm_mask_unpadded = nm_mask_flat_padded[:-padding] + else: + nm_mask_unpadded = nm_mask_flat_padded + final_mask_reshaped = nm_mask_unpadded.reshape(original_shape) + final_pruned_tensor = tensor * final_mask_reshaped return final_pruned_tensor, final_mask_reshaped + # --- Mergekit Method Definition --- # We can keep CABSMerge as a Pydantic BaseModel if Mergekit's registry.py instantiates it directly # and then MergeMethod.create re-instantiates with YAML parameters. # Or, if registry.py stores the *class* and MergeMethod.create instantiates it once with YAML params, # then it also works. Let's assume the latter for now for parameter passing simplicity. -class CABSMerge(MergeMethod, BaseModel, frozen=True): +class CABSMerge(MergeMethod, BaseModel, frozen=True): # These fields capture parameters from YAML that are sibling to 'merge_method: cabs' # They are used if Mergekit passes them directly to CABSMerge constructor. # If parameters are *only* from the nested 'parameters:' block, these can be removed, @@ -56,21 +83,37 @@ class CABSMerge(MergeMethod, BaseModel, frozen=True): default_m_val: Optional[int] = Field(default=None) pruning_order: Optional[List[str]] = Field(default=None) - method_name_arg: str = Field("cabs", alias="method_name", exclude=True) - method_pretty_name_arg: Optional[str] = Field("Conflict-Aware N:M Sparsification", alias="method_pretty_name", exclude=True) - method_reference_url_arg: Optional[str] = Field("https://arxiv.org/abs/2503.01874", alias="method_reference_url", exclude=True) + method_name_arg: str = Field("cabs", alias="method_name", exclude=True) + method_pretty_name_arg: Optional[str] = Field( + "Conflict-Aware N:M Sparsification", alias="method_pretty_name", exclude=True + ) + method_reference_url_arg: Optional[str] = Field( + "https://arxiv.org/abs/2503.01874", alias="method_reference_url", exclude=True + ) @root_validator(pre=False, skip_on_failure=True) def check_default_n_m_consistency(cls, values: Dict[str, Any]) -> Dict[str, Any]: - n = values.get('default_n_val') # Get from the instance's fields - m = values.get('default_m_val') + n = values.get("default_n_val") # Get from the instance's fields + m = values.get("default_m_val") if n is not None and m is None: - raise ValueError("If 'default_n_val' is provided, 'default_m_val' must also be provided.") + raise ValueError( + "If 'default_n_val' is provided, 'default_m_val' must also be provided." + ) if m is not None and n is None: - raise ValueError("If 'default_m_val' is provided, 'default_n_val' must also be provided.") + raise ValueError( + "If 'default_m_val' is provided, 'default_n_val' must also be provided." 
+ ) if n is not None and m is not None: - if not (isinstance(n, int) and n >= 0 and isinstance(m, int) and m > 0 and n <= m): - raise ValueError(f"Invalid default n/m values: n={n}, m={m}. Ensure 0 <= n <= m and m > 0.") + if not ( + isinstance(n, int) + and n >= 0 + and isinstance(m, int) + and m > 0 + and n <= m + ): + raise ValueError( + f"Invalid default n/m values: n={n}, m={m}. Ensure 0 <= n <= m and m > 0." + ) return values @override @@ -80,7 +123,7 @@ def name(self) -> str: @override def pretty_name(self) -> Optional[str]: return self.method_pretty_name_arg - + @override def reference_url(self) -> Optional[str]: return self.method_reference_url_arg @@ -93,43 +136,64 @@ def parameters(self) -> List[ConfigParameterDef]: # the 'parameters' argument to make_task. # Given KarcherMerge example, they are for the nested 'parameters:' block. return [ - ConfigParameterDef(name="default_n_val", type="int", required=False, default_value=None), - ConfigParameterDef(name="default_m_val", type="int", required=False, default_value=None), - ConfigParameterDef(name="pruning_order", type="list[str]", required=False, default_value=None), + ConfigParameterDef( + name="default_n_val", type="int", required=False, default_value=None + ), + ConfigParameterDef( + name="default_m_val", type="int", required=False, default_value=None + ), + ConfigParameterDef( + name="pruning_order", + type="list[str]", + required=False, + default_value=None, + ), ] @override def tensor_parameters(self) -> List[ConfigParameterDef]: return [ - ConfigParameterDef(name="weight", type="float", required=False, default_value=1.0), - ConfigParameterDef(name="n_val", type="int", required=False, default_value=None), - ConfigParameterDef(name="m_val", type="int", required=False, default_value=None), + ConfigParameterDef( + name="weight", type="float", required=False, default_value=1.0 + ), + ConfigParameterDef( + name="n_val", type="int", required=False, default_value=None + ), + ConfigParameterDef( + name="m_val", type="int", required=False, default_value=None + ), ] - + @override def make_task( self, output_weight: WeightInfo, tensors: MergeTensorInput, - base_model: Optional[ModelReference], - parameters: ImmutableMap[str, Any], # This map IS from the nested 'parameters:' block in YAML + base_model: Optional[ModelReference], + parameters: ImmutableMap[ + str, Any + ], # This map IS from the nested 'parameters:' block in YAML tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], ) -> Task: if base_model is None: raise ValueError("CABS merge requires a 'base_model'.") - + # Extract global CABS parameters SAFELY from the 'parameters' ImmutableMap default_n_val_global: Optional[int] = None if "default_n_val" in parameters: val = parameters["default_n_val"] - if isinstance(val, int): default_n_val_global = val - elif val is not None: logging.warning(f"Expected int for default_n_val, got {type(val)}") - + if isinstance(val, int): + default_n_val_global = val + elif val is not None: + logging.warning(f"Expected int for default_n_val, got {type(val)}") + default_m_val_global: Optional[int] = None if "default_m_val" in parameters: val = parameters["default_m_val"] - if isinstance(val, int): default_m_val_global = val - elif val is not None: logging.warning(f"Expected int for default_m_val, got {type(val)}") + if isinstance(val, int): + default_m_val_global = val + elif val is not None: + logging.warning(f"Expected int for default_m_val, got {type(val)}") pruning_order_global: Optional[List[str]] = None if "pruning_order" 
@@ -137,65 +201,89 @@ def make_task(
             if isinstance(val, list) and all(isinstance(s, str) for s in val):
                 pruning_order_global = val
             elif val is not None:
-                logging.warning(f"Expected list of strings for pruning_order, got {type(val)}")
+                logging.warning(
+                    f"Expected list of strings for pruning_order, got {type(val)}"
+                )
 
         # Validation for n and m consistency if both are provided globally
         if default_n_val_global is not None and default_m_val_global is None:
-            raise ValueError("If 'default_n_val' is provided in global parameters, 'default_m_val' must also be provided.")
+            raise ValueError(
+                "If 'default_n_val' is provided in global parameters, 'default_m_val' must also be provided."
+            )
         if default_m_val_global is not None and default_n_val_global is None:
-            raise ValueError("If 'default_m_val' is provided in global parameters, 'default_n_val' must also be provided.")
-        if default_n_val_global is not None and default_m_val_global is not None: # Both are provided
-            if not (default_n_val_global >= 0 and default_m_val_global > 0 and default_n_val_global <= default_m_val_global):
-                raise ValueError(f"Invalid global default n/m values: n={default_n_val_global}, m={default_m_val_global}. "
-                                 "Ensure 0 <= n <= m and m > 0.")
-        
+            raise ValueError(
+                "If 'default_m_val' is provided in global parameters, 'default_n_val' must also be provided."
+            )
+        if (
+            default_n_val_global is not None and default_m_val_global is not None
+        ):  # Both are provided
+            if not (
+                default_n_val_global >= 0
+                and default_m_val_global > 0
+                and default_n_val_global <= default_m_val_global
+            ):
+                raise ValueError(
+                    f"Invalid global default n/m values: n={default_n_val_global}, m={default_m_val_global}. "
+                    "Ensure 0 <= n <= m and m > 0."
+                )
+
         return CABSTask(
             global_default_n_val=default_n_val_global,
             global_default_m_val=default_m_val_global,
             global_pruning_order=pruning_order_global,
             tensors_input=tensors,
-            base_model_ref=base_model, 
+            base_model_ref=base_model,
             current_weight_info=output_weight,
             per_model_tensor_params=tensor_parameters,
         )
 
+
 class CABSTask(Task[torch.Tensor]):
     global_default_n_val: Optional[int]
     global_default_m_val: Optional[int]
     global_pruning_order: Optional[List[str]]
-    
+
     tensors_input: MergeTensorInput
-    base_model_ref: ModelReference 
+    base_model_ref: ModelReference
     current_weight_info: WeightInfo
     per_model_tensor_params: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
-    
-    FALLBACK_N_VAL: int = 1 
+
+    FALLBACK_N_VAL: int = 1
     FALLBACK_M_VAL: int = 4
 
     @override
-    def uses_accelerator(self) -> bool: return True
+    def uses_accelerator(self) -> bool:
+        return True
+
     @override
-    def arguments(self) -> Dict[str, Task]: return {"tensors_arg": self.tensors_input}
+    def arguments(self) -> Dict[str, Task]:
+        return {"tensors_arg": self.tensors_input}
 
-# In CABSTask class, within cabs.py
+    # In CABSTask class, within cabs.py
     def _resolve_n_and_m_for_model(self, model_ref: ModelReference) -> Tuple[int, int]:
-        per_model_n_raw: Any = None # Use Any to catch the actual type passed
+        per_model_n_raw: Any = None  # Use Any to catch the actual type passed
         per_model_m_raw: Any = None
-        
+
         model_identifier_str = str(model_ref.model)
         logging.debug(f"Resolving n and m for model: {model_identifier_str}")
 
         if model_ref in self.per_model_tensor_params:
             current_model_inner_params_map = self.per_model_tensor_params[model_ref]
-            logging.debug(f"  Per-model params for {model_identifier_str}: {dict(current_model_inner_params_map)}")
-            
+            logging.debug(
+                f"  Per-model params for {model_identifier_str}: {dict(current_model_inner_params_map)}"
+            )
+
             if "n_val" in current_model_inner_params_map:
                 per_model_n_raw = current_model_inner_params_map["n_val"]
-                logging.debug(f"  Raw per-model n_val: {per_model_n_raw} (type: {type(per_model_n_raw)})")
+                logging.debug(
+                    f"  Raw per-model n_val: {per_model_n_raw} (type: {type(per_model_n_raw)})"
+                )
             if "m_val" in current_model_inner_params_map:
                 per_model_m_raw = current_model_inner_params_map["m_val"]
-                logging.debug(f"  Raw per-model m_val: {per_model_m_raw} (type: {type(per_model_m_raw)})")
+                logging.debug(
+                    f"  Raw per-model m_val: {per_model_m_raw} (type: {type(per_model_m_raw)})"
+                )
 
         # Attempt to convert to int if they are floats representing whole numbers
         def try_convert_to_int(val: Any, name: str) -> Optional[int]:
@@ -205,10 +293,14 @@ def try_convert_to_int(val: Any, name: str) -> Optional[int]:
                 if val.is_integer():
                     return int(val)
                 else:
-                    logging.warning(f"  Cannot convert per-model {name} '{val}' to int as it's a non-whole float.")
-                    return None
-            if val is not None: # Log if it's some other unexpected type
-                logging.warning(f"  Unexpected type for per-model {name}: {type(val)}. Expected int or float representing int.")
+                    logging.warning(
+                        f"  Cannot convert per-model {name} '{val}' to int as it's a non-whole float."
+                    )
+                    return None
+            if val is not None:  # Log if it's some other unexpected type
+                logging.warning(
+                    f"  Unexpected type for per-model {name}: {type(val)}. Expected int or float representing int."
+                )
             return None
 
         n_candidate: Optional[int] = None
@@ -222,76 +314,151 @@ def try_convert_to_int(val: Any, name: str) -> Optional[int]:
         # Check if per-model n and m are consistently provided and valid AFTER conversion attempt
         if n_candidate is not None and m_candidate is not None:
             if n_candidate >= 0 and m_candidate > 0 and n_candidate <= m_candidate:
-                logging.debug(f"  Using per-model n_val={n_candidate}, m_val={m_candidate} for {model_identifier_str}")
+                logging.debug(
+                    f"  Using per-model n_val={n_candidate}, m_val={m_candidate} for {model_identifier_str}"
+                )
                 return n_candidate, m_candidate
             else:
-                logging.warning(f"  Invalid per-model n_val/m_val after conversion: n={n_candidate}, m={m_candidate} "
-                                f"for model {model_identifier_str}. Will try global defaults.")
+                logging.warning(
+                    f"  Invalid per-model n_val/m_val after conversion: n={n_candidate}, m={m_candidate} "
+                    f"for model {model_identifier_str}. Will try global defaults."
+                )
         # If only one was provided or conversion failed for one, it's an incomplete/invalid pair
-        elif n_candidate is not None or m_candidate is not None: 
-            logging.warning(f"  Incomplete or invalid per-model n_val/m_val after conversion: n={n_candidate}, m={m_candidate} "
-                            f"for model {model_identifier_str}. Both valid integers are required if one is set. Will try global defaults.")
-        
+        elif n_candidate is not None or m_candidate is not None:
+            logging.warning(
+                f"  Incomplete or invalid per-model n_val/m_val after conversion: n={n_candidate}, m={m_candidate} "
+                f"for model {model_identifier_str}. Both valid integers are required if one is set. Will try global defaults."
+            )
+
         # Try global default parameters if per-model not valid or not fully set
-        if self.global_default_n_val is not None and self.global_default_m_val is not None:
+        if (
+            self.global_default_n_val is not None
+            and self.global_default_m_val is not None
+        ):
             # Global defaults already validated by CABSMerge.make_task
-            logging.debug(f"  Using global default_n_val={self.global_default_n_val}, "
-                          f"default_m_val={self.global_default_m_val} for {model_identifier_str}")
+            logging.debug(
+                f"  Using global default_n_val={self.global_default_n_val}, "
+                f"default_m_val={self.global_default_m_val} for {model_identifier_str}"
+            )
             return self.global_default_n_val, self.global_default_m_val
 
-        logging.warning(f"  No valid per-model or global default n/m values for model {model_identifier_str} "
-                        f"on tensor {self.current_weight_info.name}. "
-                        f"Using hardcoded fallback: n={self.FALLBACK_N_VAL}, m={self.FALLBACK_M_VAL}.")
+        logging.warning(
+            f"  No valid per-model or global default n/m values for model {model_identifier_str} "
+            f"on tensor {self.current_weight_info.name}. "
+            f"Using hardcoded fallback: n={self.FALLBACK_N_VAL}, m={self.FALLBACK_M_VAL}."
+        )
         return self.FALLBACK_N_VAL, self.FALLBACK_M_VAL
 
     @override
-    def execute(self, tensors_arg: Dict[ModelReference, torch.Tensor], **_kwargs) -> torch.Tensor:
+    def execute(
+        self, tensors_arg: Dict[ModelReference, torch.Tensor], **_kwargs
+    ) -> torch.Tensor:
         base_model_identifier_str = str(self.base_model_ref.model)
         if self.base_model_ref not in tensors_arg:
-            logging.error(f"Base model '{base_model_identifier_str}' not found for '{self.current_weight_info.name}'.")
-            device_str = self.current_weight_info.device_str() if hasattr(self.current_weight_info,'device_str') and callable(self.current_weight_info.device_str) else 'cpu'
-            dtype_val = self.current_weight_info.dtype if hasattr(self.current_weight_info,'dtype') else torch.float32
+            logging.error(
+                f"Base model '{base_model_identifier_str}' not found for '{self.current_weight_info.name}'."
+            )
+            device_str = (
+                self.current_weight_info.device_str()
+                if hasattr(self.current_weight_info, "device_str")
+                and callable(self.current_weight_info.device_str)
+                else "cpu"
+            )
+            dtype_val = (
+                self.current_weight_info.dtype
+                if hasattr(self.current_weight_info, "dtype")
+                else torch.float32
+            )
             return torch.empty(0, device=torch.device(device_str), dtype=dtype_val)
-        target_device = tensors_arg[self.base_model_ref].device; target_dtype = tensors_arg[self.base_model_ref].dtype
-        merged_tensor_accumulator = tensors_arg[self.base_model_ref].clone().to(device=target_device,dtype=target_dtype)
+        target_device = tensors_arg[self.base_model_ref].device
+        target_dtype = tensors_arg[self.base_model_ref].dtype
+        merged_tensor_accumulator = (
+            tensors_arg[self.base_model_ref]
+            .clone()
+            .to(device=target_device, dtype=target_dtype)
+        )
 
         ordered_model_refs_for_ca: List[ModelReference] = []
-        model_ref_by_string_id: Dict[str, ModelReference] = {str(ref.model): ref for ref in tensors_arg.keys()}
-        
+        model_ref_by_string_id: Dict[str, ModelReference] = {
+            str(ref.model): ref for ref in tensors_arg.keys()
+        }
+
         current_pruning_order_strings = self.global_pruning_order
         if current_pruning_order_strings:
             for id_str_in_order in current_pruning_order_strings:
-                if id_str_in_order == base_model_identifier_str: continue
-                if id_str_in_order in model_ref_by_string_id: ordered_model_refs_for_ca.append(model_ref_by_string_id[id_str_in_order])
-                else: logging.warning(f"Model ID '{id_str_in_order}' from order not found for '{self.current_weight_info.name}'.")
+                if id_str_in_order == base_model_identifier_str:
+                    continue
+                if id_str_in_order in model_ref_by_string_id:
+                    ordered_model_refs_for_ca.append(
+                        model_ref_by_string_id[id_str_in_order]
+                    )
+                else:
+                    logging.warning(
+                        f"Model ID '{id_str_in_order}' from order not found for '{self.current_weight_info.name}'."
+                    )
         else:
-            sorted_non_base_string_ids = sorted([str(ref.model) for ref in tensors_arg.keys() if str(ref.model) != base_model_identifier_str])
+            sorted_non_base_string_ids = sorted(
+                [
+                    str(ref.model)
+                    for ref in tensors_arg.keys()
+                    if str(ref.model) != base_model_identifier_str
+                ]
+            )
             for id_str in sorted_non_base_string_ids:
-                if id_str in model_ref_by_string_id: ordered_model_refs_for_ca.append(model_ref_by_string_id[id_str])
-        
+                if id_str in model_ref_by_string_id:
+                    ordered_model_refs_for_ca.append(model_ref_by_string_id[id_str])
+
         if not ordered_model_refs_for_ca:
-            logging.info(f"No non-base models for '{self.current_weight_info.name}'. Returning base."); return merged_tensor_accumulator
-        cumulative_param_mask = torch.zeros_like(merged_tensor_accumulator,dtype=torch.bool,device=target_device)
-        
+            logging.info(
+                f"No non-base models for '{self.current_weight_info.name}'. Returning base."
+            )
+            return merged_tensor_accumulator
+        cumulative_param_mask = torch.zeros_like(
+            merged_tensor_accumulator, dtype=torch.bool, device=target_device
+        )
+
         for model_ref_current in ordered_model_refs_for_ca:
-            if model_ref_current not in tensors_arg: logging.warning(f"Tensor for '{str(model_ref_current.model)}' unavailable for '{self.current_weight_info.name}'."); continue
+            if model_ref_current not in tensors_arg:
+                logging.warning(
+                    f"Tensor for '{str(model_ref_current.model)}' unavailable for '{self.current_weight_info.name}'."
+                )
+                continue
+            fine_tuned_tensor_val = tensors_arg[model_ref_current].to(
+                device=target_device, dtype=target_dtype
+            )
+            base_tensor_for_diff = tensors_arg[self.base_model_ref].to(
+                device=target_device, dtype=target_dtype
+            )
 
             scaling_coefficient = 1.0
             if model_ref_current in self.per_model_tensor_params:
                 inner_params = self.per_model_tensor_params[model_ref_current]
-                if "weight" in inner_params: 
+                if "weight" in inner_params:
                     val = inner_params["weight"]
-                    if isinstance(val, (int,float)): scaling_coefficient = float(val)
-                    else: logging.warning(f"Expected float for per-model weight, got {type(val)}")
+                    if isinstance(val, (int, float)):
+                        scaling_coefficient = float(val)
+                    else:
+                        logging.warning(
+                            f"Expected float for per-model weight, got {type(val)}"
+                        )
 
-            n_val_current,m_val_current = self._resolve_n_and_m_for_model(model_ref_current)
+            n_val_current, m_val_current = self._resolve_n_and_m_for_model(
+                model_ref_current
+            )
             task_vector_val = fine_tuned_tensor_val - base_tensor_for_diff
             available_params_mask = ~cumulative_param_mask
-            candidate_task_vector = task_vector_val * available_params_mask.to(task_vector_val.dtype)
-            pruned_task_vector,newly_retained_mask = prune_n_m_structural(candidate_task_vector,n_val_current,m_val_current)
-            merged_tensor_accumulator += scaling_coefficient * pruned_task_vector.to(merged_tensor_accumulator.dtype)
-            cumulative_param_mask = torch.logical_or(cumulative_param_mask,newly_retained_mask.to(device=cumulative_param_mask.device))
+            candidate_task_vector = task_vector_val * available_params_mask.to(
+                task_vector_val.dtype
+            )
+            pruned_task_vector, newly_retained_mask = prune_n_m_structural(
+                candidate_task_vector, n_val_current, m_val_current
+            )
+            merged_tensor_accumulator += scaling_coefficient * pruned_task_vector.to(
+                merged_tensor_accumulator.dtype
+            )
+            cumulative_param_mask = torch.logical_or(
+                cumulative_param_mask,
+                newly_retained_mask.to(device=cumulative_param_mask.device),
+            )
 
         return merged_tensor_accumulator
 
     @override
diff --git a/mergekit/merge_methods/registry.py b/mergekit/merge_methods/registry.py
index 3a9f3b56..a7a83db0 100644
--- a/mergekit/merge_methods/registry.py
+++ b/mergekit/merge_methods/registry.py
@@ -5,11 +5,11 @@
 
 from mergekit.merge_methods.arcee_fusion import ArceeFusionMerge
 from mergekit.merge_methods.base import MergeMethod
+from mergekit.merge_methods.cabs import CABSMerge
 from mergekit.merge_methods.generalized_task_arithmetic import (
     ConsensusMethod,
     GeneralizedTaskArithmeticMerge,
 )
-from mergekit.merge_methods.cabs import CABSMerge
 from mergekit.merge_methods.karcher import KarcherMerge
 from mergekit.merge_methods.linear import LinearMerge
 from mergekit.merge_methods.model_stock import ModelStockMerge
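Example configuration (a minimal sketch for trying the method; not part of the patch). It assumes the standard mergekit YAML layout and uses the parameter names declared in CABSMerge.parameters() and CABSMerge.tensor_parameters(); the model names below are placeholders and the exact layout has not been verified against this branch.

merge_method: cabs
base_model: example/base-model
models:
  - model: example/finetune-a
    parameters:
      weight: 1.0
      n_val: 1
      m_val: 4
  - model: example/finetune-b
    parameters:
      weight: 0.8
parameters:
  default_n_val: 1
  default_m_val: 4
  pruning_order: ["example/finetune-a", "example/finetune-b"]
dtype: bfloat16

With a configuration along these lines, example/finetune-a's task vector is pruned to 1:4 sparsity first; example/finetune-b is then pruned only within the parameter slots not yet claimed by the cumulative mask, following the order handling and masking in CABSTask.execute above. Models without per-model n_val/m_val fall back to the global defaults, and then to the hardcoded 1:4 fallback.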