Skip to content

Commit

Permalink
Switch to inline calc of unmerge as opposed to storing it
Browse files Browse the repository at this point in the history
  • Loading branch information
metric-space committed Sep 10, 2024
1 parent 876fbbb commit 08eafe2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 26 deletions.
17 changes: 5 additions & 12 deletions mergekit/scripts/ABM/activations_based_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,21 @@ def main(
if "_unmerge" in f
]
for i in spaces:
logging.info(f"Loading merge/unmerge tensors for {i}")
logging.info(f"Loading merge tensors for {i}")
m = safetensors.torch.load_file(
os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
device=device,
)
u = safetensors.torch.load_file(
os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
device=device,
)
merge_unmerge_dictionary[i] = (
m[i].to(device, dtype=dtype),
u[i].to(device, dtype=dtype),
)
merge_unmerge_dictionary[i] = m[i].to(device, dtype=dtype)

for weight_info in model_arch_info.all_weights(config=model_config):
merge_matrix, unmerge_matrix = None, None
merge_matrix = None

if weight_info.input_space in merge_unmerge_dictionary:
_, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space].t()

if weight_info.output_space in merge_unmerge_dictionary:
merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
merge_matrix = merge_unmerge_dictionary[weight_info.output_space]

original_w = loader.get_tensor(weight_info.name, device=device)

Expand Down
19 changes: 5 additions & 14 deletions mergekit/scripts/ABM/extract_permutation_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def match_tensors_permute(

merge = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]

unmerge = merge.clone().T

return merge, unmerge
return merge


def match_tensors_permute_MHA(
Expand Down Expand Up @@ -101,9 +99,7 @@ def match_tensors_permute_MHA(
torch.cat(head_perms).clone().detach().long().to(device)
]

unmerge = merge.clone().T

return merge, unmerge
return merge


@click.command("mergekit-abm-extract-permutations")
Expand Down Expand Up @@ -183,14 +179,14 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):
correlation_matrix = calc_correlation_matrix(concatenated_feature)

if feature_space in (kq_spaces + v_spaces):
merge, unmerge = match_tensors_permute_MHA(
merge = match_tensors_permute_MHA(
correlation_matrix=correlation_matrix,
n_heads=model_config.num_attention_heads,
absval=absval,
)

else:
merge, unmerge = match_tensors_permute(
merge = match_tensors_permute(
correlation_matrix=correlation_matrix,
absval=absval,
)
Expand All @@ -200,12 +196,7 @@ def main(model1_ft, model2_ft, model_path, out_path, absval, device):
f"{out_path}/{feature_space}_merge.safetensor",
)

safetensors.torch.save_file(
{feature_space: unmerge.contiguous()},
f"{out_path}/{feature_space}_unmerge.safetensor",
)

del merge, unmerge, correlation_matrix, concatenated_feature
del merge, correlation_matrix, concatenated_feature


if __name__ == "__main__":
Expand Down

0 comments on commit 08eafe2

Please sign in to comment.