From aa64f01f21e9c76410334cb5de36736f0c8e60aa Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 29 Jul 2024 19:54:48 -0400 Subject: [PATCH 1/7] Tokenizer correction for activations extractor script --- mergekit/scripts/ABM/extract_activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mergekit/scripts/ABM/extract_activations.py b/mergekit/scripts/ABM/extract_activations.py index 7cb5961b..c8fbc4c0 100644 --- a/mergekit/scripts/ABM/extract_activations.py +++ b/mergekit/scripts/ABM/extract_activations.py @@ -221,7 +221,7 @@ def main( logging.info("Using chat template for inference") tokenize_function = lambda x: tokenizer.apply_chat_template( x, - padding="longest", + padding="max_length", max_length=max_length, truncation=True, return_dict=True, @@ -230,7 +230,7 @@ def main( logging.info("Using default tokenizer (no chat template) for inference") tokenize_function = lambda x: tokenizer( x, - padding="longest", + padding="max_length", max_length=max_length, truncation=True, ) From c57dd5e09daeac4349e7f8725735e67e7b8d97d3 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 29 Jul 2024 20:30:19 -0400 Subject: [PATCH 2/7] Code simplification --- .../scripts/ABM/activations_based_merge.py | 13 +++++-------- .../ABM/extract_permutation_matrices.py | 19 ++++++------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py index cb3c912a..efc8c652 100644 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ b/mergekit/scripts/ABM/activations_based_merge.py @@ -16,8 +16,8 @@ @click.command("mergekit-activation-based-merge") -@click.argument("model_path", type=str) -@click.argument("secondary_model_path", type=str) +@click.argument("model_path", type=str, help="Path to the anchor model") +@click.argument("secondary_model_path", type=str, help="Path to the secondary model") @click.argument("merge_unmerge_directory", type=str) 
@click.option("--out-path", "-o", required=True, type=str, help="Output model path") @click.option( @@ -121,8 +121,8 @@ def main( if merge_matrix is not None: if weight_info.is_embed: - w = (merge_matrix[0] @ w.T).T - w2 = (merge_matrix[1] @ w2.T).T + w = w @ merge_matrix[0].T + w2 = w2 @ merge_matrix[1].T else: w = merge_matrix[0] @ w w2 = merge_matrix[1] @ w2 @@ -151,10 +151,7 @@ def main( ) # average weights and save them - if merge_matrix: - w = w + w2 - else: - w = (w + w2) / 2 + w = (w + w2) / 2 writer.save_tensor(weight_info.name, w) writer.finalize() diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py index 75c58692..91af9dfc 100644 --- a/mergekit/scripts/ABM/extract_permutation_matrices.py +++ b/mergekit/scripts/ABM/extract_permutation_matrices.py @@ -39,14 +39,10 @@ def match_tensors_permute( new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] mats.append(new_mat.T) - unmerge_mats = mats + unmerge = torch.cat(mats, dim=0) + merge = unmerge.clone().T - unmerge = torch.cat(unmerge_mats, dim=0) - - merge = torch.cat(mats, dim=0) - merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) - - return merge.T, unmerge + return merge, unmerge def match_tensors_permute_MHA( @@ -111,13 +107,10 @@ def match_tensors_permute_MHA( ] mats.append(new_mat.T) - unmerge_mats = mats - - unmerge = torch.cat(unmerge_mats, dim=0) - merge = torch.cat(mats, dim=0) - merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) + unmerge = torch.cat(mats, dim=0) + merge = unmerge.clone().T - return merge.T, unmerge + return merge, unmerge @click.command("mergekit-abm-extract-permutations") From db230c430b1002933a567e4d9953f468a98162ca Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 29 Jul 2024 22:30:19 -0400 Subject: [PATCH 3/7] Remove help statements --- mergekit/scripts/ABM/activations_based_merge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py index efc8c652..29e51571 100644 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ b/mergekit/scripts/ABM/activations_based_merge.py @@ -16,8 +16,8 @@ @click.command("mergekit-activation-based-merge") -@click.argument("model_path", type=str, help="Path to the anchor model") -@click.argument("secondary_model_path", type=str, help="Path to the secondary model") +@click.argument("model_path", type=str) +@click.argument("secondary_model_path", type=str) @click.argument("merge_unmerge_directory", type=str) @click.option("--out-path", "-o", required=True, type=str, help="Output model path") @click.option( From ca71aa3b361f82f29799e627daef0cc830474c1e Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Thu, 8 Aug 2024 01:09:45 -0400 Subject: [PATCH 4/7] Only output aligned model --- .../scripts/ABM/activations_based_merge.py | 57 +++++++------------ .../ABM/extract_permutation_matrices.py | 11 +--- 2 files changed, 21 insertions(+), 47 deletions(-) diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py index 29e51571..984ffe17 100644 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ b/mergekit/scripts/ABM/activations_based_merge.py @@ -15,11 +15,12 @@ from mergekit.options import MergeOptions, add_merge_options -@click.command("mergekit-activation-based-merge") -@click.argument("model_path", type=str) +@click.command("mergekit-activation-based-align") @click.argument("secondary_model_path", type=str) @click.argument("merge_unmerge_directory", type=str) -@click.option("--out-path", "-o", required=True, type=str, help="Output model path") +@click.option( + "--out-path", "-o", required=True, type=str, help="Path to save the aligned model" +) @click.option( "--dtype", type=str, @@ -35,7 +36,6 @@ ) @add_merge_options def main( - model_path: str, secondary_model_path, merge_unmerge_directory: str, out_path: str, 
@@ -43,7 +43,6 @@ def main( device: Optional[str], merge_options: MergeOptions, ): - model = ModelReference.model_validate(model_path) secondary_model = ModelReference.model_validate(secondary_model_path) dtype = dtype_from_name(dtype) if dtype else None @@ -52,8 +51,7 @@ def main( cache.lazy_unpickle = merge_options.lazy_unpickle cache.hf_cache_dir = merge_options.transformers_cache - for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"): - cache.get(m) + cache.get(secondary_model) writer = TensorWriter( out_path=out_path, @@ -61,18 +59,19 @@ def main( safe_serialization=merge_options.safe_serialization, ) - model_config = model.config(trust_remote_code=merge_options.trust_remote_code) + model_config = secondary_model.config( + trust_remote_code=merge_options.trust_remote_code + ) model_arch_info = get_architecture_info( - model.config(trust_remote_code=merge_options.trust_remote_code) + secondary_model.config(trust_remote_code=merge_options.trust_remote_code) ) - loader_1 = cache.get(model) - loader_2 = cache.get(secondary_model) + loader = cache.get(secondary_model) os.makedirs(out_path, exist_ok=True) merge_unmerge_dictionary = {} - # load files from merge_unmerge_directory + spaces = [ f.split("_unmerge")[0] for f in os.listdir(merge_unmerge_directory) @@ -98,68 +97,50 @@ def main( if weight_info.input_space in merge_unmerge_dictionary: _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space] - unmerge_matrix = unmerge_matrix.chunk(2, dim=0) if weight_info.output_space in merge_unmerge_dictionary: merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space] - merge_matrix = merge_matrix.chunk(2, dim=1) - original_w = loader_1.get_tensor(weight_info.name, device=device) - original_w2 = loader_2.get_tensor(weight_info.name, device=device) + original_w = loader.get_tensor(weight_info.name, device=device) if dtype is not None: original_w = original_w.to(dtype=dtype) original_w2 = original_w2.to(dtype=dtype) w = 
torch.clone(original_w) - w2 = torch.clone(original_w2) if not merge_matrix and not unmerge_matrix: logging.warning( - f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix" + f"❌ Weight {weight_info.name} for model has no merge or unmerge matrix !!" ) if merge_matrix is not None: if weight_info.is_embed: w = w @ merge_matrix[0].T - w2 = w2 @ merge_matrix[1].T else: w = merge_matrix[0] @ w - w2 = merge_matrix[1] @ w2 if unmerge_matrix is not None: w = w @ unmerge_matrix[0] - w2 = w2 @ unmerge_matrix[1] - # check if weights have not mutated, if yes then shoot warning if torch.allclose(original_w, w): logging.warning( - f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge" + f"❌ Weight {weight_info.name} for input model has NOT mutated during merge" ) else: logging.warning( - f"✅ Weight {weight_info.name} for model 1 has mutated during merge" + f"✅ Weight {weight_info.name} for input model has mutated during merge" ) - if torch.allclose(original_w2, w2): - logging.warning( - f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge" - ) - else: - logging.warning( - f"✅ Weight {weight_info.name} for model 2 has mutated during merge" - ) - - # average weights and save them - w = (w + w2) / 2 writer.save_tensor(weight_info.name, w) writer.finalize() - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(secondary_model_path) tokenizer.save_pretrained(out_path, safe_serialization=True) - # write config - model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code) + model_out_config = secondary_model.config( + trust_remote_code=merge_options.trust_remote_code + ) if dtype: model_out_config.torch_dtype = dtype model_out_config.save_pretrained(out_path) diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py index 91af9dfc..7ec45d23 100644 --- 
a/mergekit/scripts/ABM/extract_permutation_matrices.py +++ b/mergekit/scripts/ABM/extract_permutation_matrices.py @@ -29,17 +29,13 @@ def match_tensors_permute( Om = correlation_matrix.shape[0] // 2 device = correlation_matrix.device - mats = [torch.eye(Om, device=device)] - corr_submatrix = correlation_matrix[:Om, Om:].cpu().numpy() if absval: corr_submatrix = np.absolute(corr_submatrix) _, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True) - new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] - mats.append(new_mat.T) + unmerge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] - unmerge = torch.cat(mats, dim=0) merge = unmerge.clone().T return merge, unmerge @@ -59,7 +55,6 @@ def match_tensors_permute_MHA( device = correlation_matrix.device query_size = Om // n_heads - mats = [torch.eye(Om, device=device)] head_perms = [] costs = np.ones((n_heads, n_heads)) * -sys.maxsize @@ -102,12 +97,10 @@ def match_tensors_permute_MHA( head_perm = col_inds_storage[head_1][head_2] head_perms.append(torch.tensor(head_perm + query_size * head_2)) - new_mat = torch.eye(Om, device=device)[ + unmerge = torch.eye(Om, device=device)[ torch.cat(head_perms).clone().detach().long().to(device) ] - mats.append(new_mat.T) - unmerge = torch.cat(mats, dim=0) merge = unmerge.clone().T return merge, unmerge From 17c0d44f3c525180acfe8325b8dfb9e253f0eaf3 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Thu, 8 Aug 2024 15:31:06 -0400 Subject: [PATCH 5/7] More corrections --- mergekit/scripts/ABM/activations_based_merge.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py index 984ffe17..70b8e8b6 100644 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ b/mergekit/scripts/ABM/activations_based_merge.py @@ -105,23 +105,22 @@ def main( if dtype is not None: original_w = original_w.to(dtype=dtype) 
- original_w2 = original_w2.to(dtype=dtype) w = torch.clone(original_w) - if not merge_matrix and not unmerge_matrix: + if merge_matrix is None and unmerge_matrix is None: logging.warning( f"❌ Weight {weight_info.name} for model has no merge or unmerge matrix !!" ) if merge_matrix is not None: if weight_info.is_embed: - w = w @ merge_matrix[0].T + w = w @ merge_matrix.T else: - w = merge_matrix[0] @ w + w = merge_matrix @ w if unmerge_matrix is not None: - w = w @ unmerge_matrix[0] + w = w @ unmerge_matrix if torch.allclose(original_w, w): logging.warning( From 876fbbbc69feacceba947644a9951f0878016b6d Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Thu, 8 Aug 2024 16:23:38 -0400 Subject: [PATCH 6/7] More corrections --- mergekit/scripts/ABM/extract_permutation_matrices.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py index 7ec45d23..db10706a 100644 --- a/mergekit/scripts/ABM/extract_permutation_matrices.py +++ b/mergekit/scripts/ABM/extract_permutation_matrices.py @@ -34,9 +34,9 @@ def match_tensors_permute( corr_submatrix = np.absolute(corr_submatrix) _, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True) - unmerge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] + merge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] - merge = unmerge.clone().T + unmerge = merge.clone().T return merge, unmerge @@ -97,11 +97,11 @@ def match_tensors_permute_MHA( head_perm = col_inds_storage[head_1][head_2] head_perms.append(torch.tensor(head_perm + query_size * head_2)) - unmerge = torch.eye(Om, device=device)[ + merge = torch.eye(Om, device=device)[ torch.cat(head_perms).clone().detach().long().to(device) ] - merge = unmerge.clone().T + unmerge = merge.clone().T return merge, unmerge From 08eafe2c727402be4adde07fb3575d6a86756ddf Mon Sep 17 00:00:00 2001 From: Luke Meyers 
Date: Tue, 10 Sep 2024 00:12:36 -0400 Subject: [PATCH 7/7] Switch to inline calc of unmerge as opposed to storing it --- .../scripts/ABM/activations_based_merge.py | 17 +++++------------ .../ABM/extract_permutation_matrices.py | 19 +++++-------------- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py index 70b8e8b6..5c38050a 100644 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ b/mergekit/scripts/ABM/activations_based_merge.py @@ -78,28 +78,21 @@ def main( if "_unmerge" in f ] for i in spaces: - logging.info(f"Loading merge/unmerge tensors for {i}") + logging.info(f"Loading merge tensors for {i}") m = safetensors.torch.load_file( os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"), device=device, ) - u = safetensors.torch.load_file( - os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"), - device=device, - ) - merge_unmerge_dictionary[i] = ( - m[i].to(device, dtype=dtype), - u[i].to(device, dtype=dtype), - ) + merge_unmerge_dictionary[i] = m[i].to(device, dtype=dtype) for weight_info in model_arch_info.all_weights(config=model_config): - merge_matrix, unmerge_matrix = None, None + merge_matrix = None if weight_info.input_space in merge_unmerge_dictionary: - _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space] + unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space].t() if weight_info.output_space in merge_unmerge_dictionary: - merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space] + merge_matrix = merge_unmerge_dictionary[weight_info.output_space] original_w = loader.get_tensor(weight_info.name, device=device) diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py index db10706a..23fad64b 100644 --- a/mergekit/scripts/ABM/extract_permutation_matrices.py +++ b/mergekit/scripts/ABM/extract_permutation_matrices.py @@ 
-36,9 +36,7 @@ def match_tensors_permute( merge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] - unmerge = merge.clone().T - - return merge, unmerge + return merge def match_tensors_permute_MHA( @@ -101,9 +99,7 @@ def match_tensors_permute_MHA( torch.cat(head_perms).clone().detach().long().to(device) ] - unmerge = merge.clone().T - - return merge, unmerge + return merge @click.command("mergekit-abm-extract-permutations") @@ -183,14 +179,14 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device): correlation_matrix = calc_correlation_matrix(concatenated_feature) if feature_space in (kq_spaces + v_spaces): - merge, unmerge = match_tensors_permute_MHA( + merge = match_tensors_permute_MHA( correlation_matrix=correlation_matrix, n_heads=model_config.num_attention_heads, absval=absval, ) else: - merge, unmerge = match_tensors_permute( + merge = match_tensors_permute( correlation_matrix=correlation_matrix, absval=absval, ) @@ -200,12 +196,7 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device): f"{out_path}/{feature_space}_merge.safetensor", ) - safetensors.torch.save_file( - {feature_space: unmerge.contiguous()}, - f"{out_path}/{feature_space}_unmerge.safetensor", - ) - - del merge, unmerge, correlation_matrix, concatenated_feature + del merge, correlation_matrix, concatenated_feature if __name__ == "__main__":