diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py
index cb3c912a..5c38050a 100644
--- a/mergekit/scripts/ABM/activations_based_merge.py
+++ b/mergekit/scripts/ABM/activations_based_merge.py
@@ -15,11 +15,12 @@ from mergekit.options import MergeOptions, add_merge_options
 
 
 
-@click.command("mergekit-activation-based-merge")
-@click.argument("model_path", type=str)
+@click.command("mergekit-activation-based-align")
 @click.argument("secondary_model_path", type=str)
 @click.argument("merge_unmerge_directory", type=str)
-@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
+@click.option(
+    "--out-path", "-o", required=True, type=str, help="Path to save the aligned model"
+)
 @click.option(
     "--dtype",
     type=str,
@@ -35,7 +36,6 @@
 )
 @add_merge_options
 def main(
-    model_path: str,
     secondary_model_path,
     merge_unmerge_directory: str,
     out_path: str,
@@ -43,7 +43,6 @@ def main(
     device: Optional[str],
     merge_options: MergeOptions,
 ):
-    model = ModelReference.model_validate(model_path)
     secondary_model = ModelReference.model_validate(secondary_model_path)
 
     dtype = dtype_from_name(dtype) if dtype else None
@@ -52,8 +51,7 @@ def main(
     cache.lazy_unpickle = merge_options.lazy_unpickle
     cache.hf_cache_dir = merge_options.transformers_cache
 
-    for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"):
-        cache.get(m)
+    cache.get(secondary_model)
 
     writer = TensorWriter(
         out_path=out_path,
@@ -61,108 +59,81 @@ def main(
         safe_serialization=merge_options.safe_serialization,
     )
 
-    model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
+    model_config = secondary_model.config(
+        trust_remote_code=merge_options.trust_remote_code
+    )
     model_arch_info = get_architecture_info(
-        model.config(trust_remote_code=merge_options.trust_remote_code)
+        secondary_model.config(trust_remote_code=merge_options.trust_remote_code)
     )
 
-    loader_1 = cache.get(model)
-    loader_2 = cache.get(secondary_model)
+    loader = cache.get(secondary_model)
 
     os.makedirs(out_path, exist_ok=True)
 
     merge_unmerge_dictionary = {}
-    # load files from merge_unmerge_directory
+    # Find spaces via "<space>_merge.safetensor"; "_unmerge" files are no longer written.
     spaces = [
-        f.split("_unmerge")[0]
+        f[: -len("_merge.safetensor")]
         for f in os.listdir(merge_unmerge_directory)
-        if "_unmerge" in f
+        if f.endswith("_merge.safetensor")
     ]
     for i in spaces:
-        logging.info(f"Loading merge/unmerge tensors for {i}")
+        logging.info(f"Loading merge tensors for {i}")
         m = safetensors.torch.load_file(
             os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
             device=device,
         )
-        u = safetensors.torch.load_file(
-            os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
-            device=device,
-        )
-        merge_unmerge_dictionary[i] = (
-            m[i].to(device, dtype=dtype),
-            u[i].to(device, dtype=dtype),
-        )
+        merge_unmerge_dictionary[i] = m[i].to(device, dtype=dtype)
 
     for weight_info in model_arch_info.all_weights(config=model_config):
+        # Reset both per weight: each is only assigned below when its space is known.
         merge_matrix, unmerge_matrix = None, None
 
         if weight_info.input_space in merge_unmerge_dictionary:
-            _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
-            unmerge_matrix = unmerge_matrix.chunk(2, dim=0)
+            unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space].t()
 
         if weight_info.output_space in merge_unmerge_dictionary:
-            merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
-            merge_matrix = merge_matrix.chunk(2, dim=1)
+            merge_matrix = merge_unmerge_dictionary[weight_info.output_space]
 
-        original_w = loader_1.get_tensor(weight_info.name, device=device)
-        original_w2 = loader_2.get_tensor(weight_info.name, device=device)
+        original_w = loader.get_tensor(weight_info.name, device=device)
 
         if dtype is not None:
             original_w = original_w.to(dtype=dtype)
-            original_w2 = original_w2.to(dtype=dtype)
 
         w = torch.clone(original_w)
-        w2 = torch.clone(original_w2)
 
-        if not merge_matrix and not unmerge_matrix:
+        if merge_matrix is None and unmerge_matrix is None:
             logging.warning(
-                f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix"
+                f"❌ Weight {weight_info.name} for model has no merge or unmerge matrix !!"
             )
 
         if merge_matrix is not None:
             if weight_info.is_embed:
-                w = (merge_matrix[0] @ w.T).T
-                w2 = (merge_matrix[1] @ w2.T).T
+                w = w @ merge_matrix.T
             else:
-                w = merge_matrix[0] @ w
-                w2 = merge_matrix[1] @ w2
+                w = merge_matrix @ w
 
         if unmerge_matrix is not None:
-            w = w @ unmerge_matrix[0]
-            w2 = w2 @ unmerge_matrix[1]
+            w = w @ unmerge_matrix
 
-        # check if weights have not mutated, if yes then shoot warning
         if torch.allclose(original_w, w):
             logging.warning(
-                f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge"
+                f"❌ Weight {weight_info.name} for input model has NOT mutated during merge"
             )
         else:
             logging.warning(
-                f"✅ Weight {weight_info.name} for model 1 has mutated during merge"
+                f"✅ Weight {weight_info.name} for input model has mutated during merge"
             )
 
-        if torch.allclose(original_w2, w2):
-            logging.warning(
-                f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge"
-            )
-        else:
-            logging.warning(
-                f"✅ Weight {weight_info.name} for model 2 has mutated during merge"
-            )
-
-        # average weights and save them
-        if merge_matrix:
-            w = w + w2
-        else:
-            w = (w + w2) / 2
         writer.save_tensor(weight_info.name, w)
     writer.finalize()
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(secondary_model_path)
     tokenizer.save_pretrained(out_path, safe_serialization=True)
 
-    # write config
-    model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code)
+    model_out_config = secondary_model.config(
+        trust_remote_code=merge_options.trust_remote_code
+    )
     if dtype:
         model_out_config.torch_dtype = dtype
     model_out_config.save_pretrained(out_path)
diff --git a/mergekit/scripts/ABM/extract_activations.py b/mergekit/scripts/ABM/extract_activations.py
index 7cb5961b..c8fbc4c0 100644
--- a/mergekit/scripts/ABM/extract_activations.py
+++ b/mergekit/scripts/ABM/extract_activations.py
@@ -221,7 +221,7 @@ def main(
         logging.info("Using chat template for inference")
         tokenize_function = lambda x: tokenizer.apply_chat_template(
             x,
-            padding="longest",
+            padding="max_length",
             max_length=max_length,
             truncation=True,
             return_dict=True,
@@ -230,7 +230,7 @@ def main(
         logging.info("Using default tokenizer (no chat template) for inference")
         tokenize_function = lambda x: tokenizer(
             x,
-            padding="longest",
+            padding="max_length",
             max_length=max_length,
             truncation=True,
         )
diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py
index 75c58692..23fad64b 100644
--- a/mergekit/scripts/ABM/extract_permutation_matrices.py
+++ b/mergekit/scripts/ABM/extract_permutation_matrices.py
@@ -29,24 +29,14 @@ def match_tensors_permute(
     Om = correlation_matrix.shape[0] // 2
     device = correlation_matrix.device
 
-    mats = [torch.eye(Om, device=device)]
-
     corr_submatrix = correlation_matrix[:Om, Om:].cpu().numpy()
     if absval:
         corr_submatrix = np.absolute(corr_submatrix)
     _, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True)
 
-    new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]
-    mats.append(new_mat.T)
-
-    unmerge_mats = mats
-
-    unmerge = torch.cat(unmerge_mats, dim=0)
+    merge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]
 
-    merge = torch.cat(mats, dim=0)
-    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
-
-    return merge.T, unmerge
+    return merge
 
 
 def match_tensors_permute_MHA(
@@ -63,7 +53,6 @@ def match_tensors_permute_MHA(
     device = correlation_matrix.device
     query_size = Om // n_heads
 
-    mats = [torch.eye(Om, device=device)]
     head_perms = []
 
     costs = np.ones((n_heads, n_heads)) * -sys.maxsize
@@ -106,18 +95,11 @@ def match_tensors_permute_MHA(
         head_perm = col_inds_storage[head_1][head_2]
         head_perms.append(torch.tensor(head_perm + query_size * head_2))
 
-    new_mat = torch.eye(Om, device=device)[
+    merge = torch.eye(Om, device=device)[
         torch.cat(head_perms).clone().detach().long().to(device)
     ]
-    mats.append(new_mat.T)
-
-    unmerge_mats = mats
-    unmerge = torch.cat(unmerge_mats, dim=0)
 
-    merge = torch.cat(mats, dim=0)
-    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
-
-    return merge.T, unmerge
+    return merge
 
 
 @click.command("mergekit-abm-extract-permutations")
@@ -197,14 +179,14 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):
         correlation_matrix = calc_correlation_matrix(concatenated_feature)
 
         if feature_space in (kq_spaces + v_spaces):
-            merge, unmerge = match_tensors_permute_MHA(
+            merge = match_tensors_permute_MHA(
                 correlation_matrix=correlation_matrix,
                 n_heads=model_config.num_attention_heads,
                 absval=absval,
             )
 
         else:
-            merge, unmerge = match_tensors_permute(
+            merge = match_tensors_permute(
                 correlation_matrix=correlation_matrix,
                 absval=absval,
             )
@@ -214,12 +196,7 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):
             f"{out_path}/{feature_space}_merge.safetensor",
         )
 
-        safetensors.torch.save_file(
-            {feature_space: unmerge.contiguous()},
-            f"{out_path}/{feature_space}_unmerge.safetensor",
-        )
-
-        del merge, unmerge, correlation_matrix, concatenated_feature
+        del merge, correlation_matrix, concatenated_feature
 
 
 if __name__ == "__main__":