From c2b0f066430e61aab83f6bfe77c6a5efd912eea4 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Thu, 30 May 2024 15:32:58 +0100 Subject: [PATCH 01/64] introduced metric_methods, closely following implementation of merge_methods --- mergekit/metric_methods/MSE.py | 87 +++++++++++++++++++ mergekit/metric_methods/PCA_rank.py | 125 ++++++++++++++++++++++++++++ mergekit/metric_methods/SMAPE.py | 91 ++++++++++++++++++++ mergekit/metric_methods/__init__.py | 46 ++++++++++ mergekit/metric_methods/base.py | 29 +++++++ mergekit/metric_methods/cossim.py | 79 ++++++++++++++++++ mergekit/metric_methods/scale.py | 93 +++++++++++++++++++++ 7 files changed, 550 insertions(+) create mode 100644 mergekit/metric_methods/MSE.py create mode 100644 mergekit/metric_methods/PCA_rank.py create mode 100644 mergekit/metric_methods/SMAPE.py create mode 100644 mergekit/metric_methods/__init__.py create mode 100644 mergekit/metric_methods/base.py create mode 100644 mergekit/metric_methods/cossim.py create mode 100644 mergekit/metric_methods/scale.py diff --git a/mergekit/metric_methods/MSE.py b/mergekit/metric_methods/MSE.py new file mode 100644 index 00000000..6d36089e --- /dev/null +++ b/mergekit/metric_methods/MSE.py @@ -0,0 +1,87 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
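The MSE task added in this file computes a per-neuron (per-row) mean squared error between corresponding weight matrices of two models. A minimal standalone sketch of the same computation, outside mergekit's Task machinery; the random 8x16 tensors are hypothetical stand-ins for real checkpoint weights and nothing below is part of the patch itself:

import torch

def mse_per_neuron(a: torch.Tensor, b: torch.Tensor) -> dict:
    # Both inputs are (num_neurons, hidden_dim) weight matrices of equal shape.
    assert a.shape == b.shape, "Tensors must have the same shape"
    per_row = ((a - b) ** 2).mean(dim=1)  # one MSE value per output neuron
    return {"MSE_full": per_row, "MSE_mean": per_row.mean(), "MSE_std": per_row.std()}

a, b = torch.randn(8, 16), torch.randn(8, 16)
print(mse_per_neuron(a, b)["MSE_mean"].item())
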
+ +from typing import Any, Dict, List, Optional + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors +from mergekit.metric_methods.base import MetricMethod + +class MSEMetricTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + keys = list(tensors.keys()) + tensors = [tensors[key] for key in keys] + + unique_shapes = set(t.shape for t in tensors) + if len(unique_shapes) != 1: + raise RuntimeError( + f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" + ) + if len(tensors) != 2: + raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") + + if 'mlp' not in self.weight_info.name: + return + + res = {} + # pairwise similarity of corresponding rows in weights matrix + + # Ensure the tensors have the same shape + assert tensors[0].shape == tensors[0].shape, "Tensors must have the same shape" + + # Compute the squared differences + squared_diff = (tensors[0] - tensors[1]) ** 2 + + + # Compute the mean of squared differences for each row + mse_per_neuron = torch.mean(squared_diff, dim=1) + + res['MSE_full'] = mse_per_neuron + res['MSE_mean'] = mse_per_neuron.mean() + res['MSE_std'] = mse_per_neuron.std() + + return res + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + +class MSEMetric(MetricMethod): + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: GatherTensors, + **_kwargs, + ) -> Task: + return MSEMetricTask( + gather_tensors=tensors, + weight_info=output_weight, + ) diff --git a/mergekit/metric_methods/PCA_rank.py b/mergekit/metric_methods/PCA_rank.py new file mode 100644 index 00000000..f268bed6 --- /dev/null +++ b/mergekit/metric_methods/PCA_rank.py @@ -0,0 +1,125 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from typing import Any, Dict, List, Optional + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors +from mergekit.metric_methods.base import MetricMethod +from mergekit.merge_methods.base import ConfigParameterDef +from mergekit.merge_methods.rectify_embed import rectify_embed_sizes + +import torch.nn.functional as F + +def pca_components_for_variance(X, variance_threshold=0.99, rescale=True): + """ + Compute the number of principal components required to explain + at least `variance_threshold` of the total variance in the dataset X using PyTorch. + + Args: + X (torch.Tensor): The data matrix. 
Rows are samples and columns are features. + variance_threshold (float): The fraction of total variance that we want to capture. + + Returns: + int: The number of principal components required to capture the specified variance threshold. + """ + # Standardize the data (mean 0 and variance 1) + X_mean = torch.mean(X, dim=0) + X_std = torch.std(X, dim=0, unbiased=False) + X = X - X_mean + + if rescale: + X = X / X_std + + # Compute the covariance matrix + covariance_matrix = torch.mm(X.T, X) / (X.shape[0] - 1) + + # Perform SVD on the covariance matrix + U, S, V = torch.svd(covariance_matrix) + + # Calculate explained variance ratios + explained_variance_ratio = S / torch.sum(S) + cumsum_variance = torch.cumsum(explained_variance_ratio, dim=0) + + # Determine the number of components needed to surpass the variance threshold + num_components = torch.where(cumsum_variance >= variance_threshold)[0][0] + 1 + + return num_components.item() + + +class PCA_RankTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] + normalize: bool + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + keys = list(tensors.keys()) + + tensors = [tensors[key] for key in keys] + + + unique_shapes = set(t.shape for t in tensors) + if len(unique_shapes) != 1: + raise RuntimeError( + f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" + ) + if len(tensors) != 1: + raise RuntimeError(f"Expected 1 tensors, got {len(tensors)}") + + if 'mlp' not in self.weight_info.name: + return + + res = {} + X = tensors[0] + + res['num_components_99'] = pca_components_for_variance(X, variance_threshold=0.99, rescale=True) + res['num_components_95'] = pca_components_for_variance(X, variance_threshold=0.95, rescale=True) + return res + + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + +class PCA_RankMetric(MetricMethod): + + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: GatherTensors, + parameters: Dict[str, Any], + tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], + **_kwargs, + ) -> Task: + return PCA_RankTask( + gather_tensors=tensors, + tensor_parameters=tensor_parameters, + normalize=parameters["normalize"], + weight_info=output_weight, + ) diff --git a/mergekit/metric_methods/SMAPE.py b/mergekit/metric_methods/SMAPE.py new file mode 100644 index 00000000..53110d35 --- /dev/null +++ b/mergekit/metric_methods/SMAPE.py @@ -0,0 +1,91 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
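For reference, the pca_components_for_variance helper introduced above can be exercised on its own. A small sketch on synthetic low-rank data, assuming the module is importable exactly as laid out in this patch; the matrix shapes and noise level are arbitrary illustration choices:

import torch
from mergekit.metric_methods.PCA_rank import pca_components_for_variance

torch.manual_seed(0)
# Rank-5 signal plus a little noise: the 0.99 variance threshold should be
# reached with only a handful of principal components.
X = torch.randn(256, 5) @ torch.randn(5, 64) + 0.01 * torch.randn(256, 64)
print(pca_components_for_variance(X, variance_threshold=0.99, rescale=False))
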
+ +from typing import Any, Dict, List, Optional + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors +from mergekit.metric_methods.base import MetricMethod + +import torch.nn.functional as F + +class SMAPEMetricTask(Task[torch.Tensor]): + """ + Symmetric Mean Absolute Percentage Error (SMAPE) + + SMAPE = 100 * |y - y_hat| / ((|y| + |y_hat|) / 2) + """ + gather_tensors: GatherTensors + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + keys = list(tensors.keys()) + tensors = [tensors[key] for key in keys] + + unique_shapes = set(t.shape for t in tensors) + if len(unique_shapes) != 1: + raise RuntimeError( + f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" + ) + if len(tensors) != 2: + raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") + + if 'mlp' not in self.weight_info.name: + return + + res = {} + + # Ensure the tensors have the same shape + assert tensors[0].shape == tensors[0].shape, "Tensors must have the same shape" + + # SMAPE + numerator = torch.abs(tensors[0] - tensors[1]) + denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) / 2 + smape = 100 * torch.mean(torch.div(numerator, denominator), dim=1) + + res['SMAPE_full'] = smape + res['SMAPE_mean'] = smape.mean() + res['SMAPE_std'] = smape.std() + + return res + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + +class SMAPEMetric(MetricMethod): + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: GatherTensors, + **_kwargs, + ) -> Task: + return SMAPEMetricTask( + gather_tensors=tensors, + weight_info=output_weight, + ) diff --git a/mergekit/metric_methods/__init__.py b/mergekit/metric_methods/__init__.py new file mode 100644 index 00000000..e3e46ad2 --- /dev/null +++ b/mergekit/metric_methods/__init__.py @@ -0,0 +1,46 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
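The SMAPE task above reduces to a few row-wise tensor operations. A standalone sketch of that arithmetic on hypothetical random tensors, mirroring the formula in the task's docstring:

import torch

a, b = torch.randn(8, 16), torch.randn(8, 16)
numerator = torch.abs(a - b)
denominator = (torch.abs(a) + torch.abs(b)) / 2
smape = 100 * torch.mean(numerator / denominator, dim=1)  # one percentage per row
print(smape.mean().item(), smape.std().item())
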
+ +from mergekit.metric_methods.base import MetricMethod +from mergekit.metric_methods.cossim import CossimMetric +from mergekit.metric_methods.PCA_rank import PCA_RankMetric +from mergekit.metric_methods.MSE import MSEMetric +from mergekit.metric_methods.SMAPE import SMAPEMetric +from mergekit.metric_methods.scale import ScaleMetric + + +def get(method: str) -> MetricMethod: + if method == "cossim": + return CossimMetric() + elif method == "PCA_rank": + return PCA_RankMetric() + elif method == "MSE": + return MSEMetric() + elif method == "SMAPE": + return SMAPEMetric() + elif method == "scale": + return ScaleMetric() + raise RuntimeError(f"Unimplemented merge method {method}") + + +__all__ = [ + "MetricMethod", + "get", + "CossimMetric", + "MSEMetric", + "SMAPEMetric", + "ScaleMetric", + "PCA_RankMetric", +] diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py new file mode 100644 index 00000000..df42a39c --- /dev/null +++ b/mergekit/metric_methods/base.py @@ -0,0 +1,29 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from pydantic import BaseModel + +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors + +from mergekit.merge_methods.base import MergeMethod, ConfigParameterDef + +class MetricMethod(MergeMethod): + pass \ No newline at end of file diff --git a/mergekit/metric_methods/cossim.py b/mergekit/metric_methods/cossim.py new file mode 100644 index 00000000..598c8035 --- /dev/null +++ b/mergekit/metric_methods/cossim.py @@ -0,0 +1,79 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
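The cosine-similarity task defined in this file compares corresponding rows (output neurons) of two weight matrices. A standalone sketch of the underlying computation on hypothetical random tensors, independent of the Task plumbing:

import torch
import torch.nn.functional as F

a, b = torch.randn(8, 16), torch.randn(8, 16)
cossim = F.cosine_similarity(a, b, dim=1)  # one similarity value per row/neuron
print(cossim.mean().item(), cossim.std().item())
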
+ +from typing import Any, Dict, List, Optional + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors +from mergekit.metric_methods.base import MetricMethod + +import torch.nn.functional as F + +class CossimMetricTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + keys = list(tensors.keys()) + tensors = [tensors[key] for key in keys] + + unique_shapes = set(t.shape for t in tensors) + if len(unique_shapes) != 1: + raise RuntimeError( + f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" + ) + if len(tensors) != 2: + raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") + + if 'mlp' not in self.weight_info.name: + return + + res = {} + # pairwise similarity of corresponding rows in weights matrix + + res['cossim_full'] = F.cosine_similarity(tensors[0], tensors[1], dim=1) # this might get memory intensive, consider binning + res['cossim_mean'] = res['cossim_full'].mean() + res['cossim_std'] = res['cossim_full'].std() + + return res + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + +class CossimMetric(MetricMethod): + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: GatherTensors, + **_kwargs, + ) -> Task: + return CossimMetricTask( + gather_tensors=tensors, + weight_info=output_weight, + ) diff --git a/mergekit/metric_methods/scale.py b/mergekit/metric_methods/scale.py new file mode 100644 index 00000000..b6c2a3a8 --- /dev/null +++ b/mergekit/metric_methods/scale.py @@ -0,0 +1,93 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from typing import Any, Dict, List, Optional + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors +from mergekit.metric_methods.base import MetricMethod + +import torch.nn.functional as F + +class ScaleMetricTask(Task[torch.Tensor]): + """ + Relative difference in scale per neuron. Complemetary to the cosine similarity metric. 
+ + scale_diff (X, Y) = absolute value of the difference in magnitude of X and Y, normalized by the average magnitude of X and Y + """ + gather_tensors: GatherTensors + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + keys = list(tensors.keys()) + tensors = [tensors[key] for key in keys] + + unique_shapes = set(t.shape for t in tensors) + if len(unique_shapes) != 1: + raise RuntimeError( + f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" + ) + if len(tensors) != 2: + raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") + + if 'mlp' not in self.weight_info.name: + return + + res = {} + + # Ensure the tensors have the same shape + assert tensors[0].shape == tensors[0].shape, "Tensors must have the same shape" + + # + norm_0 = torch.norm(tensors[0], dim=1) + norm_1 = torch.norm(tensors[1], dim=1) + + scale_diff = torch.abs(norm_0 - norm_1) + scale_diff = scale_diff / ((norm_0 + norm_1) / 2) + + res['scale_full'] = scale_diff + res['scale_mean'] = scale_diff.mean() + res['scale_std'] = scale_diff.std() + + return res + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + +class ScaleMetric(MetricMethod): + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: GatherTensors, + **_kwargs, + ) -> Task: + return ScaleMetricTask( + gather_tensors=tensors, + weight_info=output_weight, + ) From 64e54bf4b834866f6a08e8ec4dfba78479e4cdb8 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Thu, 30 May 2024 15:33:54 +0100 Subject: [PATCH 02/64] measure.py closely follows metric.py to apply chosen metrics --- mergekit/measure.py | 92 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 mergekit/measure.py diff --git a/mergekit/measure.py b/mergekit/measure.py new file mode 100644 index 00000000..95e78028 --- /dev/null +++ b/mergekit/measure.py @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
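The scale task at the end of the previous patch compares per-neuron magnitudes via a normalised difference of row norms. A standalone sketch of that arithmetic on hypothetical random tensors, shown here only to make the metric concrete before the measurement driver below:

import torch

a, b = torch.randn(8, 16), torch.randn(8, 16)
norm_a, norm_b = torch.norm(a, dim=1), torch.norm(b, dim=1)
scale_diff = torch.abs(norm_a - norm_b) / ((norm_a + norm_b) / 2)  # relative gap in per-row magnitude
print(scale_diff.mean().item(), scale_diff.std().item())
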
+ +import logging + +import tqdm +import transformers + +from mergekit.architecture import ArchitectureInfo, get_architecture_info +from mergekit.config import MergeConfiguration +from mergekit.graph import Executor +from mergekit.io.tasks import LoaderCache +from mergekit.options import MergeOptions +from mergekit.plan import MergePlanner +from mergekit.merge import _model_out_config + + +def run_measure( + merge_config: MergeConfiguration, + out_path: str, + options: MergeOptions, +): + if options.random_seed is not None: + transformers.trainer_utils.set_seed(options.random_seed) + + if not merge_config.models and not merge_config.slices: + raise RuntimeError("No output requested") + + model_arch_info = [ + get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) + for m in merge_config.referenced_models() + ] + if not options.allow_crimes: + if not all(a == model_arch_info[0] for a in model_arch_info[1:]): + raise RuntimeError( + "Must specify --allow-crimes to attempt to mix different architectures" + ) + arch_info = model_arch_info[0] + + # initialize loader cache and set options + loader_cache = LoaderCache() + loader_cache.setup(options=options) + + # create config for output model + cfg_out = _model_out_config( + merge_config, arch_info, trust_remote_code=options.trust_remote_code + ) + + # warm up loader cache + for model in ( + pbar := tqdm.tqdm( + merge_config.referenced_models(), + desc="Warmup loader cache", + disable=options.quiet, + ) + ): + loader_cache.get(model) + del pbar + + logging.info("Planning operations") + targets = MergePlanner( + merge_config, + arch_info, + options=options, + out_model_config=cfg_out, + ).plan_to_disk(out_path=out_path) + + exec = Executor( + tasks=targets, + math_device="cuda" if options.cuda else "cpu", + storage_device="cuda" if options.low_cpu_memory else "cpu", + ) + + res = [] + for _task, value in exec.run(quiet=options.quiet): + res.append((_task, value)) + + return res + +__all__ = ["MergeOptions", "run_merge"] From 7e37d692ed4bd28ba7fe0aacfee909e5bee02931 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Thu, 30 May 2024 15:34:56 +0100 Subject: [PATCH 03/64] plot tools include class to handle output of run_metrics, and a class to plot (simple) interactive graph --- mergekit/plot_tools/plot_tools.py | 222 ++++++++++++++++++++++++++++++ run_metrics.py | 55 ++++++++ 2 files changed, 277 insertions(+) create mode 100644 mergekit/plot_tools/plot_tools.py create mode 100644 run_metrics.py diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py new file mode 100644 index 00000000..a10fc074 --- /dev/null +++ b/mergekit/plot_tools/plot_tools.py @@ -0,0 +1,222 @@ +import numpy as np +from typing import List, Dict, Optional, Any, Tuple +from mergekit.graph import Task +import networkx as nx +import plotly.graph_objects as go +import matplotlib.pyplot as plt + +class MetricsHandler(): + def __init__(self): + self.all_metrics: Dict[str, Dict[str, Any]] = {} + self.all_stats: List = [] + self.layer_names: List[str] = [] + + def load_metrics(self, metrics: List[Tuple[Task, Dict[str, Any]]]): + stats = set() + for task, metric in metrics: + if metric is not None: + self.all_metrics[task.weight_info.name] = {'metric':metric, + 'weight_info':task.weight_info} + self.layer_names.append(task.weight_info.name) + stats.update(metric.keys()) + + self.all_stats = list(stats) + + def layers(self) -> List[str]: + return self.layer_names + + def stats(self) -> List[str]: + return self.all_stats + + def 
metric_at_layer(self, layer_name: str) -> Dict[str, Any]: + if layer_name not in self.all_metrics: + raise ValueError(f"Layer {layer_name} not found in metrics") + return self.all_metrics[layer_name]['metric'] + + def info_at_layer(self, layer_name: str): + if layer_name not in self.all_metrics: + raise ValueError(f"Layer {layer_name} not found in metrics") + return self.all_metrics[layer_name]['weight_info'] + + def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): + fig, ax = plt.subplots() + + ax_kwargs = ['ylabel', 'title', 'ylim', 'xticklabels'] + plot_kwargs = {k: v for k, v in kwargs.items() if k not in ax_kwargs} + + self._plot_with_optional_error_bars(ax, stat, plot_kwargs) + self._set_plot_attributes(ax, stat, ax_kwargs, **kwargs) + if save_to: + plt.savefig(save_to) + plt.show() + plt.close() + + def _plot_with_optional_error_bars(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): + """ + Plot the stat values with optional error bars. + + Args: + ax: The matplotlib Axes object. + stat_values (List[float]): The values of the stat to plot. + std_values (Optional[List[float]]): The standard deviation values for error bars. + **kwargs: Additional keyword arguments for plotting. + """ + std_values = None + if f'{stat}_mean' in self.all_stats: + std_stat = f"{stat}_std" + stat = f'{stat}_mean' + if std_stat in self.all_stats: + std_values = [self.all_metrics[layer]['metric'].get(std_stat) for layer in self.layer_names] + + assert (stat in self.all_stats), f"Stat {stat} not found in metrics" + stat_values = [self.all_metrics[layer]['metric'][stat] for layer in self.layer_names] + if std_values: + ax.errorbar(self.layer_names, stat_values, yerr=std_values, fmt='-o', **plot_kwargs) + else: + ax.plot(stat_values, **plot_kwargs) + + def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): + """ + Set the attributes of the plot. + + Args: + ax: The matplotlib Axes object. + stat (str): The name of the stat. + **kwargs: Additional keyword arguments for plot attributes. + """ + # Defaults + ax.set_ylabel(kwargs.get('ylabel', stat)) + ax.set_xticks(np.arange(len(self.layer_names))) + ax.set_xticklabels(self.layer_names, rotation=45) + ax.set_title(kwargs.get('title', f'{stat.capitalize()}')) + + # Set additional attributes + for kwarg in ax_kwargs: + if kwarg in kwargs: + getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) + + +class NeuralNetworkGraph: + def __init__(self, metrics: List[Tuple['Task', Dict[str, Any]]]): + self.metrics = metrics + self.metric_handler = MetricsHandler() + self.metric_handler.load_metrics(metrics) + self.hierarchy = [] + self.common_parts = self._find_common_parts() + self.graph = nx.DiGraph() + self._parse_task_names() + + def _find_common_parts(self) -> List[str]: + """ + Find common parts in all task names. + """ + if not self.metrics: + return [] + + common_parts = None + for task, _ in self.metrics: + parts = task.weight_info.name.split('.') + if common_parts is None: + common_parts = set(parts) + else: + common_parts.intersection_update(parts) + + return list(common_parts) + + def _remove_common_parts(self, name: str) -> str: + """ + Remove common parts from the task name. 
+ """ + parts = name.split('.') + cleaned_parts = [part for part in parts if part not in self.common_parts] + return '.'.join(cleaned_parts) + + def _parse_task_names(self): + for task, _ in self.metrics: + name = task.weight_info.name + self.hierarchy.append(name) + + def _add_nodes_and_edges(self, hierarchy): + # Current implementation builds linear graph + # Parallel paths (heads, skips) not yet supported + prev = None + for name in hierarchy: + self.graph.add_node(name) + if prev: + self.graph.add_edge(prev, name) + prev = name + + def construct_graph(self): + self._add_nodes_and_edges(self.hierarchy) + + def plot_graph(self, save_to: str = None): + """ + Plot the graph using Plotly for interactivity. + """ + pos = nx.planar_layout(self.graph) + edge_x = [] + edge_y = [] + for edge in self.graph.edges(): + x0, y0 = pos[edge[0]] + x1, y1 = pos[edge[1]] + edge_x.extend([x0, x1, None]) + edge_y.extend([y0, y1, None]) + + edge_trace = go.Scatter( + x=edge_x, y=edge_y, + line=dict(width=1, color='#888'), + hoverinfo='none', + mode='lines') + + metric_to_plot = [m for m in self.metric_handler.stats() if 'mean' in m][0] + + node_x = [] + node_y = [] + node_text = [] + node_values = [] + for node in self.graph.nodes(): + x, y = pos[node] + node_x.append(x) + node_y.append(y) + metric_value = self.metric_handler.metric_at_layer(node)[metric_to_plot] + node_text.append(f"{self._remove_common_parts(node)}: {metric_value:.2f}{'%' if 'SMAPE' in metric_to_plot else ''}") + node_values.append(metric_value) + + # Normalize node values for coloring + norm = plt.Normalize(vmin=min(node_values), vmax=max(node_values)) + node_colors = [norm(value) for value in node_values] + + node_trace = go.Scatter( + x=node_x, y=node_y, + mode='markers+text', + text=node_text, + textposition='top center', + hoverinfo='text', + marker=dict( + showscale=True, + colorscale='Viridis', + color=node_colors, + cmin=min(node_values).item(), + cmax=max(node_values).item(), + size=10, + colorbar=dict( + thickness=15, + title='Metric Value', + xanchor='left', + titleside='right', + ), + line_width=2)) + + + fig = go.Figure(data=[edge_trace, node_trace], + layout=go.Layout( + showlegend=False, + hovermode='closest', + margin=dict(b=0, l=0, r=0, t=0), + xaxis=dict(showgrid=False, zeroline=False), + yaxis=dict(showgrid=False, zeroline=False))) + + if save_to: + fig.write_html(save_to) + fig.show() + diff --git a/run_metrics.py b/run_metrics.py new file mode 100644 index 00000000..1425407d --- /dev/null +++ b/run_metrics.py @@ -0,0 +1,55 @@ +#%% +OUTPUT_PATH = "./merged" # folder to store the result in +LORA_MERGE_CACHE = "/tmp" # change if you want to keep these for some reason +CONFIG_YML = "./examples/linear_small.yml" # merge configuration file +COPY_TOKENIZER = True # you want a tokenizer? 
yeah, that's what i thought +LAZY_UNPICKLE = False # experimental low-memory model loader +LOW_CPU_MEMORY = False # enable if you somehow have more VRAM than RAM+swap + +# actually do merge +import torch +import yaml + +from mergekit.config import MergeConfiguration +from mergekit.merge import MergeOptions, run_merge +from mergekit.measure import run_measure +import matplotlib.pyplot as plt +import numpy as np +from mergekit.plot_tools.plot_tools import MetricsHandler, NeuralNetworkGraph +#%% + +with open(CONFIG_YML, "r", encoding="utf-8") as fp: + merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) + +out = run_measure( + merge_config, + out_path=OUTPUT_PATH, + options=MergeOptions( + lora_merge_cache=LORA_MERGE_CACHE, + cuda=torch.cuda.is_available(), + copy_tokenizer=COPY_TOKENIZER, + lazy_unpickle=LAZY_UNPICKLE, + low_cpu_memory=LOW_CPU_MEMORY, + ), +) +# %% + +# %% +handler = MetricsHandler() +handler.load_metrics(out) +handler.stats() + +# %% +handler.layers() +# %% +# handler.line_plot('cossim_mean', title='Cosine Similarity Mean', xticklabels=[f'layer {i}' for i in range(len(handler.layers()))],ylim=(0.90,1.0))#, xticklabels=[f'layer {i}' for i in range(len(handler.layers()))]) +# %% + + + +# Example usage: +metrics = handler.layers() +nn_graph = NeuralNetworkGraph([pair for pair in out if pair[1] is not None]) +nn_graph.construct_graph() +nn_graph.plot_graph() +# %% From 2effd41c0387cd2c5c10c2978381fc250a876559 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Thu, 30 May 2024 15:35:17 +0100 Subject: [PATCH 04/64] only minor changes required to existing mergekit code --- examples/linear_small.yml | 10 ++++++++++ mergekit/plan.py | 30 ++++++++---------------------- 2 files changed, 18 insertions(+), 22 deletions(-) create mode 100644 examples/linear_small.yml diff --git a/examples/linear_small.yml b/examples/linear_small.yml new file mode 100644 index 00000000..add503f2 --- /dev/null +++ b/examples/linear_small.yml @@ -0,0 +1,10 @@ +models: + - model: BEE-spoke-data/smol_llama-220M-GQA + + - model: BEE-spoke-data/smol_llama-220M-openhermes + + # - model: psmathur/orca_mini_v3_13b + # - model: garage-bAInd/Platypus2-13B + +metric_method: scale +dtype: float32 diff --git a/mergekit/plan.py b/mergekit/plan.py index 90c4f698..e1d6734a 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -17,7 +17,7 @@ from functools import lru_cache from typing import Any, List, Optional, Tuple -from mergekit import merge_methods +from mergekit import merge_methods, metric_methods from mergekit.architecture import ( ArchitectureInfo, ConfiguredArchitectureInfo, @@ -40,6 +40,7 @@ TensorWriterTask, ) from mergekit.merge_methods import MergeMethod +from mergekit.metric_methods import MetricMethod from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge from mergekit.options import MergeOptions from mergekit.tokenizer import BuildTokenizer @@ -66,7 +67,10 @@ def __init__( self.arch_info = arch_info self.options = options self.out_model_config = out_model_config - self._method = merge_methods.get(config.merge_method) + if getattr(config, "merge_method", None): + self._method = merge_methods.get(config.merge_method) + elif getattr(config, "metric_method", None): + self._method = metric_methods.get(config.metric_method) if config.tokenizer_source: self._tokenizer_task = BuildTokenizer( @@ -246,31 +250,13 @@ def plan_to_disk(self, out_path: str) -> List[Task]: """Plan the merge to be streamed to disk, returning a list of tasks.""" self._plan() - writer_task = 
TensorWriterTask( - out_path=out_path, - max_shard_size=self.options.out_shard_size, - safe_serialization=self.options.safe_serialization, - ) save_tasks = [] for weight, tensor_task in self._tensors: save_tasks.append( - SaveTensor( - tensor_name=weight.name, - tensor_task=tensor_task, - writer_task=writer_task, - clone=self.options.clone_tensors, - optional=weight.optional, - dtype=weight.force_dtype, - ) + tensor_task ) - finalize = FinalizeModel( - tensor_save_tasks=tuple(save_tasks), writer_task=writer_task - ) - res = save_tasks + [finalize] - if self._tokenizer_task: - res.append(self._tokenizer_task) - return res + return save_tasks def plan_in_memory(self) -> List[ReturnTensor]: """Plan the merge to be performed in memory.""" From ed59d9bfb51437c1202e79a80b6e8d0044f83aa5 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Thu, 30 May 2024 15:36:21 +0100 Subject: [PATCH 05/64] only minor changes required to existing mergekit --- mergekit/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mergekit/config.py b/mergekit/config.py index cff31921..f4180e29 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -82,7 +82,8 @@ class OutputSliceDefinition(BaseModel): class MergeConfiguration(BaseModel): - merge_method: str + merge_method: Optional[str] = None + metric_method: Optional[str] = None slices: Optional[List[OutputSliceDefinition]] = None models: Optional[List[InputModelDefinition]] = None parameters: Optional[Dict[str, ParameterSetting]] = None From 564d45c537a2069b51b9e2a5fd31f2a25af3dac9 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 3 Jun 2024 11:59:32 +0100 Subject: [PATCH 06/64] Implemented interactive dashboard for metrics visualisation --- mergekit/metric_methods/SMAPE.py | 91 --------------- mergekit/metric_methods/__init__.py | 19 +--- mergekit/metric_methods/all_metrics.py | 149 +++++++++++++++++++++++++ mergekit/metric_methods/cossim.py | 79 ------------- mergekit/metric_methods/scale.py | 93 --------------- mergekit/plot_tools/plot_tools.py | 144 +++++++++++++++--------- run_metrics.py | 79 +++++++++---- 7 files changed, 301 insertions(+), 353 deletions(-) delete mode 100644 mergekit/metric_methods/SMAPE.py create mode 100644 mergekit/metric_methods/all_metrics.py delete mode 100644 mergekit/metric_methods/cossim.py delete mode 100644 mergekit/metric_methods/scale.py diff --git a/mergekit/metric_methods/SMAPE.py b/mergekit/metric_methods/SMAPE.py deleted file mode 100644 index 53110d35..00000000 --- a/mergekit/metric_methods/SMAPE.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
- -from typing import Any, Dict, List, Optional - -import torch - -from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference -from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod - -import torch.nn.functional as F - -class SMAPEMetricTask(Task[torch.Tensor]): - """ - Symmetric Mean Absolute Percentage Error (SMAPE) - - SMAPE = 100 * |y - y_hat| / ((|y| + |y_hat|) / 2) - """ - gather_tensors: GatherTensors - weight_info: WeightInfo - - def uses_accelerator(self) -> bool: - return True - - def arguments(self) -> Dict[str, Task]: - return {"tensors": self.gather_tensors} - - def execute( - self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs - ) -> torch.Tensor: - keys = list(tensors.keys()) - tensors = [tensors[key] for key in keys] - - unique_shapes = set(t.shape for t in tensors) - if len(unique_shapes) != 1: - raise RuntimeError( - f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" - ) - if len(tensors) != 2: - raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") - - if 'mlp' not in self.weight_info.name: - return - - res = {} - - # Ensure the tensors have the same shape - assert tensors[0].shape == tensors[0].shape, "Tensors must have the same shape" - - # SMAPE - numerator = torch.abs(tensors[0] - tensors[1]) - denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) / 2 - smape = 100 * torch.mean(torch.div(numerator, denominator), dim=1) - - res['SMAPE_full'] = smape - res['SMAPE_mean'] = smape.mean() - res['SMAPE_std'] = smape.std() - - return res - - def group_label(self) -> Optional[str]: - return self.gather_tensors.group_label() - - -class SMAPEMetric(MetricMethod): - def make_task( - self, - *, - output_weight: WeightInfo, - tensors: GatherTensors, - **_kwargs, - ) -> Task: - return SMAPEMetricTask( - gather_tensors=tensors, - weight_info=output_weight, - ) diff --git a/mergekit/metric_methods/__init__.py b/mergekit/metric_methods/__init__.py index e3e46ad2..bd43c462 100644 --- a/mergekit/metric_methods/__init__.py +++ b/mergekit/metric_methods/__init__.py @@ -14,33 +14,24 @@ # along with this program. If not, see http://www.gnu.org/licenses/. from mergekit.metric_methods.base import MetricMethod -from mergekit.metric_methods.cossim import CossimMetric from mergekit.metric_methods.PCA_rank import PCA_RankMetric from mergekit.metric_methods.MSE import MSEMetric -from mergekit.metric_methods.SMAPE import SMAPEMetric -from mergekit.metric_methods.scale import ScaleMetric +from mergekit.metric_methods.all_metrics import AllMetric def get(method: str) -> MetricMethod: - if method == "cossim": - return CossimMetric() - elif method == "PCA_rank": + if method == "PCA_rank": return PCA_RankMetric() elif method == "MSE": return MSEMetric() - elif method == "SMAPE": - return SMAPEMetric() - elif method == "scale": - return ScaleMetric() - raise RuntimeError(f"Unimplemented merge method {method}") + elif method == "all": + return AllMetric() + raise RuntimeError(f"Unimplemented metric method {method}") __all__ = [ "MetricMethod", "get", - "CossimMetric", "MSEMetric", - "SMAPEMetric", - "ScaleMetric", "PCA_RankMetric", ] diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py new file mode 100644 index 00000000..7d20d347 --- /dev/null +++ b/mergekit/metric_methods/all_metrics.py @@ -0,0 +1,149 @@ +# Copyright (C) 2024 Charles O. 
Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from typing import Any, Dict, List, Optional + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.graph import Task +from mergekit.io.tasks import GatherTensors +from mergekit.metric_methods.base import MetricMethod +from mergekit.metric_methods.base import ConfigParameterDef +import torch.nn.functional as F +import numpy as np + +def binning(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: + bin_counts, bin_edges = np.histogram(tensor.numpy(), bins=n_bins) + bin_widths = np.diff(bin_edges) + return bin_counts, bin_edges, bin_widths + +def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): + unique_shapes = set(t.shape for t in tensors) + if len(unique_shapes) != 1: + raise RuntimeError(f"Tensor size mismatch for {weight_info.name}, sizes: {list(unique_shapes)}") + if expected_tensors: + if len(tensors) != expected_tensors: + raise RuntimeError(f"Expected {expected_tensors} tensors, got {len(tensors)}") + +def SMAPE( + tensors: List[torch.Tensor], **_kwargs +) -> Dict[str, Any]: + numerator = torch.abs(tensors[0] - tensors[1]) + denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) + smape = torch.mean(torch.div(numerator, denominator), dim=1) + + hist_info = binning(smape, 100) + return { + 'SMAPE_full': { + 'count': hist_info[0], + 'edges': hist_info[1], + 'widths': hist_info[2] + }, + 'SMAPE_mean': smape.mean(), + 'SMAPE_std': smape.std() + } + +def cossim( + tensors: List[torch.Tensor], **_kwargs +) -> torch.Tensor: + cossim = F.cosine_similarity(tensors[0], tensors[1], dim=1) + if _kwargs.get('angular_distance'): + cossim = torch.acos(cossim.clamp(min=-1, max=1))/torch.pi + + hist_info = binning(cossim, 100) + return { + 'cossim_full': { + 'count': hist_info[0], + 'edges': hist_info[1], + 'widths': hist_info[2] + }, + 'cossim_mean': cossim.mean(), + 'cossim_std': cossim.std() + } + +def scale( + tensors: List[torch.Tensor], **_kwargs +) -> torch.Tensor: + + norm_0 = torch.norm(tensors[0], dim=1) + norm_1 = torch.norm(tensors[1], dim=1) + + scale_diff = torch.abs(norm_0 - norm_1) + scale_diff = scale_diff / ((norm_0 + norm_1) / 2) + + hist_info = binning(scale_diff, 100) + return { + 'scale_full': { + 'count': hist_info[0], + 'edges': hist_info[1], + 'widths': hist_info[2] + }, + 'scale_mean': scale_diff.mean(), + 'scale_std': scale_diff.std() + } + +class AllMetricTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + weight_info: WeightInfo + angular_distance: bool + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + tensors = list(tensors.values()) + validate_tensors(tensors, 
self.weight_info, expected_tensors=2) + if 'mlp' not in self.weight_info.name: + return + + res = {} + + res.update(cossim(tensors, angular_distance=self.angular_distance)) + res.update(SMAPE(tensors)) + res.update(scale(tensors)) + + return res + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + +class AllMetric(MetricMethod): + def parameters(self) -> List[ConfigParameterDef]: + return [ + ConfigParameterDef(name="angular_distance", required=False, default_value=False), + ] + def make_task( + self, + *, + output_weight: WeightInfo, + parameters: Optional[Dict[str, Any]] = None, + tensors: GatherTensors, + **_kwargs, + ) -> Task: + return AllMetricTask( + gather_tensors=tensors, + weight_info=output_weight, + angular_distance=parameters["angular_distance"] + ) + + diff --git a/mergekit/metric_methods/cossim.py b/mergekit/metric_methods/cossim.py deleted file mode 100644 index 598c8035..00000000 --- a/mergekit/metric_methods/cossim.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -from typing import Any, Dict, List, Optional - -import torch - -from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference -from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod - -import torch.nn.functional as F - -class CossimMetricTask(Task[torch.Tensor]): - gather_tensors: GatherTensors - weight_info: WeightInfo - - def uses_accelerator(self) -> bool: - return True - - def arguments(self) -> Dict[str, Task]: - return {"tensors": self.gather_tensors} - - def execute( - self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs - ) -> torch.Tensor: - keys = list(tensors.keys()) - tensors = [tensors[key] for key in keys] - - unique_shapes = set(t.shape for t in tensors) - if len(unique_shapes) != 1: - raise RuntimeError( - f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" - ) - if len(tensors) != 2: - raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") - - if 'mlp' not in self.weight_info.name: - return - - res = {} - # pairwise similarity of corresponding rows in weights matrix - - res['cossim_full'] = F.cosine_similarity(tensors[0], tensors[1], dim=1) # this might get memory intensive, consider binning - res['cossim_mean'] = res['cossim_full'].mean() - res['cossim_std'] = res['cossim_full'].std() - - return res - - def group_label(self) -> Optional[str]: - return self.gather_tensors.group_label() - - -class CossimMetric(MetricMethod): - def make_task( - self, - *, - output_weight: WeightInfo, - tensors: GatherTensors, - **_kwargs, - ) -> Task: - return CossimMetricTask( - gather_tensors=tensors, - weight_info=output_weight, - ) diff --git a/mergekit/metric_methods/scale.py b/mergekit/metric_methods/scale.py deleted file mode 100644 index 
b6c2a3a8..00000000 --- a/mergekit/metric_methods/scale.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -from typing import Any, Dict, List, Optional - -import torch - -from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference -from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod - -import torch.nn.functional as F - -class ScaleMetricTask(Task[torch.Tensor]): - """ - Relative difference in scale per neuron. Complemetary to the cosine similarity metric. - - scale_diff (X, Y) = absolute value of the difference in magnitude of X and Y, normalized by the average magnitude of X and Y - """ - gather_tensors: GatherTensors - weight_info: WeightInfo - - def uses_accelerator(self) -> bool: - return True - - def arguments(self) -> Dict[str, Task]: - return {"tensors": self.gather_tensors} - - def execute( - self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs - ) -> torch.Tensor: - keys = list(tensors.keys()) - tensors = [tensors[key] for key in keys] - - unique_shapes = set(t.shape for t in tensors) - if len(unique_shapes) != 1: - raise RuntimeError( - f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" - ) - if len(tensors) != 2: - raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") - - if 'mlp' not in self.weight_info.name: - return - - res = {} - - # Ensure the tensors have the same shape - assert tensors[0].shape == tensors[0].shape, "Tensors must have the same shape" - - # - norm_0 = torch.norm(tensors[0], dim=1) - norm_1 = torch.norm(tensors[1], dim=1) - - scale_diff = torch.abs(norm_0 - norm_1) - scale_diff = scale_diff / ((norm_0 + norm_1) / 2) - - res['scale_full'] = scale_diff - res['scale_mean'] = scale_diff.mean() - res['scale_std'] = scale_diff.std() - - return res - - def group_label(self) -> Optional[str]: - return self.gather_tensors.group_label() - - -class ScaleMetric(MetricMethod): - def make_task( - self, - *, - output_weight: WeightInfo, - tensors: GatherTensors, - **_kwargs, - ) -> Task: - return ScaleMetricTask( - gather_tensors=tensors, - weight_info=output_weight, - ) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index a10fc074..bb65c3b9 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -6,37 +6,49 @@ import matplotlib.pyplot as plt class MetricsHandler(): + """ + Object to handle metrics output. Allows for easy plotting of metrics by layer and across layers. + + Input: + Use the load_metrics method to load the metrics into the handler. + metrics: List of tasks and their metrics. This is the output of the run_measure function in mergekit.measure. + + Attributes: + all_stats: Dictionary of recorded statistics for each layer. e.g. 
{'layer_name': {'cossim_mean': 0.5, 'cossim_std': 0.1}} + stat_names: List of names of all statistics available. e.g. ['cossim_mean', 'cossim_std'] + layer_names: List of layer names. + + Methods: + load_metrics: Load the metrics into the handler. + stats_at_layer: Get the metrics for a specific layer. + info_at_layer: Get the weight info for a specific layer. + line_plot: Plot a line plot of the chosen stat across layers. + plot_node_hist: Plot a histogram of the stat for a specific layer. + """ def __init__(self): - self.all_metrics: Dict[str, Dict[str, Any]] = {} - self.all_stats: List = [] + self.all_stats: Dict[str, Dict[str, Any]] = {} + self.stat_names: List = [] self.layer_names: List[str] = [] def load_metrics(self, metrics: List[Tuple[Task, Dict[str, Any]]]): - stats = set() for task, metric in metrics: if metric is not None: - self.all_metrics[task.weight_info.name] = {'metric':metric, + self.all_stats[task.weight_info.name] = {'metric':metric, 'weight_info':task.weight_info} self.layer_names.append(task.weight_info.name) - stats.update(metric.keys()) + self.stat_names.extend(metric.keys()) - self.all_stats = list(stats) - - def layers(self) -> List[str]: - return self.layer_names + self.stat_names = list(set(self.stat_names)) - def stats(self) -> List[str]: - return self.all_stats - - def metric_at_layer(self, layer_name: str) -> Dict[str, Any]: - if layer_name not in self.all_metrics: - raise ValueError(f"Layer {layer_name} not found in metrics") - return self.all_metrics[layer_name]['metric'] + def stats_at_layer(self, layer_name: str) -> Dict[str, Any]: + if layer_name not in self.all_stats: + raise ValueError(f"Layer {layer_name} not found") + return self.all_stats[layer_name]['metric'] def info_at_layer(self, layer_name: str): - if layer_name not in self.all_metrics: - raise ValueError(f"Layer {layer_name} not found in metrics") - return self.all_metrics[layer_name]['weight_info'] + if layer_name not in self.all_stats: + raise ValueError(f"Layer {layer_name} not found") + return self.all_stats[layer_name]['weight_info'] def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): fig, ax = plt.subplots() @@ -44,14 +56,14 @@ def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): ax_kwargs = ['ylabel', 'title', 'ylim', 'xticklabels'] plot_kwargs = {k: v for k, v in kwargs.items() if k not in ax_kwargs} - self._plot_with_optional_error_bars(ax, stat, plot_kwargs) + self._plot(ax, stat, plot_kwargs) self._set_plot_attributes(ax, stat, ax_kwargs, **kwargs) if save_to: plt.savefig(save_to) plt.show() plt.close() - def _plot_with_optional_error_bars(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): + def _plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): """ Plot the stat values with optional error bars. @@ -62,14 +74,14 @@ def _plot_with_optional_error_bars(self, ax, stat:str, plot_kwargs: Optional[Dic **kwargs: Additional keyword arguments for plotting. 
""" std_values = None - if f'{stat}_mean' in self.all_stats: + if f'{stat}_mean' in self.stat_names: std_stat = f"{stat}_std" stat = f'{stat}_mean' - if std_stat in self.all_stats: - std_values = [self.all_metrics[layer]['metric'].get(std_stat) for layer in self.layer_names] + if std_stat in self.stat_names: + std_values = [self.all_stats[layer]['metric'].get(std_stat) for layer in self.layer_names] - assert (stat in self.all_stats), f"Stat {stat} not found in metrics" - stat_values = [self.all_metrics[layer]['metric'][stat] for layer in self.layer_names] + assert (stat in self.stat_names), f"Stat {stat} not found" + stat_values = [self.all_stats[layer]['metric'][stat] for layer in self.layer_names] if std_values: ax.errorbar(self.layer_names, stat_values, yerr=std_values, fmt='-o', **plot_kwargs) else: @@ -88,17 +100,33 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): ax.set_ylabel(kwargs.get('ylabel', stat)) ax.set_xticks(np.arange(len(self.layer_names))) ax.set_xticklabels(self.layer_names, rotation=45) - ax.set_title(kwargs.get('title', f'{stat.capitalize()}')) + ax.set_title(kwargs.get('title', f'{stat.replace("_", " ").capitalize()}')) # Set additional attributes for kwarg in ax_kwargs: if kwarg in kwargs: getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) + + def plot_node_hist(self, layer_name: str, stat: str): + + bin_counts, bin_edges, bin_widths = self.all_stats[layer_name]['metric'][stat].values() + # Create a bar chart using Plotly + return go.Bar( + x=bin_edges[:-1], + y=bin_counts, + width=bin_widths, + marker=dict( + color='blue', + line=dict( + color='black', + width=1 + ) + ) + ) -class NeuralNetworkGraph: +class ModelGraph: def __init__(self, metrics: List[Tuple['Task', Dict[str, Any]]]): - self.metrics = metrics self.metric_handler = MetricsHandler() self.metric_handler.load_metrics(metrics) self.hierarchy = [] @@ -110,12 +138,9 @@ def _find_common_parts(self) -> List[str]: """ Find common parts in all task names. """ - if not self.metrics: - return [] - common_parts = None - for task, _ in self.metrics: - parts = task.weight_info.name.split('.') + for task_name, _ in self.metric_handler.all_stats.items(): + parts = task_name.split('.') if common_parts is None: common_parts = set(parts) else: @@ -132,9 +157,8 @@ def _remove_common_parts(self, name: str) -> str: return '.'.join(cleaned_parts) def _parse_task_names(self): - for task, _ in self.metrics: - name = task.weight_info.name - self.hierarchy.append(name) + for task_name, _ in self.metric_handler.all_stats.items(): + self.hierarchy.append(task_name) def _add_nodes_and_edges(self, hierarchy): # Current implementation builds linear graph @@ -153,7 +177,10 @@ def plot_graph(self, save_to: str = None): """ Plot the graph using Plotly for interactivity. """ - pos = nx.planar_layout(self.graph) + # Manually set positions for a straight line layout. 
+ # Not yet implemented for more complex layouts with Parallel paths + pos = {node: (i, i/5) for i, node in enumerate(self.graph.nodes())} + edge_x = [] edge_y = [] for edge in self.graph.edges(): @@ -168,23 +195,31 @@ def plot_graph(self, save_to: str = None): hoverinfo='none', mode='lines') - metric_to_plot = [m for m in self.metric_handler.stats() if 'mean' in m][0] + # Find all metrics that contain 'mean' in their keys + metrics_to_plot = [m for m in self.metric_handler.stat_names if 'mean' in m] + + node_x,node_y,node_text,hover_text = [], [], [], [] + node_values = {metric: [] for metric in metrics_to_plot} - node_x = [] - node_y = [] - node_text = [] - node_values = [] for node in self.graph.nodes(): x, y = pos[node] node_x.append(x) node_y.append(y) - metric_value = self.metric_handler.metric_at_layer(node)[metric_to_plot] - node_text.append(f"{self._remove_common_parts(node)}: {metric_value:.2f}{'%' if 'SMAPE' in metric_to_plot else ''}") - node_values.append(metric_value) + metric_values = self.metric_handler.stats_at_layer(node) + + # Build the text for each node + hover = self._remove_common_parts(node) + for metric in metrics_to_plot: + if metric in metric_values: + value = metric_values[metric] + hover += f"
{metric.replace('_', ' ').capitalize()}: {value:.4f}{'%' if 'SMAPE' in metric else ''}" + node_values[metric].append(value) - # Normalize node values for coloring - norm = plt.Normalize(vmin=min(node_values), vmax=max(node_values)) - node_colors = [norm(value) for value in node_values] + node_text.append(node) + hover_text.append(hover) + + first_metric = metrics_to_plot[0] + node_colors = [value.item() for value in node_values[first_metric]] node_trace = go.Scatter( x=node_x, y=node_y, @@ -192,22 +227,22 @@ def plot_graph(self, save_to: str = None): text=node_text, textposition='top center', hoverinfo='text', + hovertext=hover_text, marker=dict( showscale=True, colorscale='Viridis', color=node_colors, - cmin=min(node_values).item(), - cmax=max(node_values).item(), + cmin=min(node_values[first_metric]).item(), + cmax=max(node_values[first_metric]).item(), size=10, colorbar=dict( thickness=15, - title='Metric Value', + title=first_metric.replace('_', ' ').capitalize(), xanchor='left', titleside='right', ), line_width=2)) - fig = go.Figure(data=[edge_trace, node_trace], layout=go.Layout( showlegend=False, @@ -218,5 +253,6 @@ def plot_graph(self, save_to: str = None): if save_to: fig.write_html(save_to) - fig.show() + return fig + diff --git a/run_metrics.py b/run_metrics.py index 1425407d..1e445e48 100644 --- a/run_metrics.py +++ b/run_metrics.py @@ -6,17 +6,17 @@ LAZY_UNPICKLE = False # experimental low-memory model loader LOW_CPU_MEMORY = False # enable if you somehow have more VRAM than RAM+swap -# actually do merge import torch import yaml from mergekit.config import MergeConfiguration -from mergekit.merge import MergeOptions, run_merge +from mergekit.merge import MergeOptions from mergekit.measure import run_measure -import matplotlib.pyplot as plt -import numpy as np -from mergekit.plot_tools.plot_tools import MetricsHandler, NeuralNetworkGraph -#%% +from mergekit.plot_tools.plot_tools import ModelGraph +import dash +from dash import dcc, html +from dash.dependencies import Input, Output +import plotly.graph_objects as go with open(CONFIG_YML, "r", encoding="utf-8") as fp: merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) @@ -32,24 +32,59 @@ low_cpu_memory=LOW_CPU_MEMORY, ), ) -# %% -# %% -handler = MetricsHandler() -handler.load_metrics(out) -handler.stats() +nn_graph = ModelGraph([pair for pair in out if pair[1] is not None]) +nn_graph.construct_graph() + +# Initialize the Dash app +app = dash.Dash(__name__) -# %% -handler.layers() -# %% -# handler.line_plot('cossim_mean', title='Cosine Similarity Mean', xticklabels=[f'layer {i}' for i in range(len(handler.layers()))],ylim=(0.90,1.0))#, xticklabels=[f'layer {i}' for i in range(len(handler.layers()))]) -# %% +app.layout = html.Div([ + html.Label('Model Architecture Graph | Hover for node stats | Click for node plots', style={'font-family': 'Arial'}), + dcc.Graph( + id='graph', + figure=nn_graph.plot_graph(), + ), + dcc.Graph(id='node-details'), + html.Label('Select Metric:', style={'font-family': 'Arial'}), + # Add a dropdown menu to select the metric + dcc.Dropdown( + id='metric-dropdown', + options=[ + {'label': 'SMAPE', 'value': 'SMAPE_full'}, + {'label': 'Cossim', 'value': 'cossim_full'}, + {'label': 'Scale', 'value': 'scale_full'} + ], + value='cossim_full', + style={'font-family': 'Arial'} + ) +]) +@app.callback( + Output('node-details', 'figure'), + [Input('graph', 'clickData'), + Input('metric-dropdown', 'value')]) +def display_node_data(clickData, selected_metric): + if clickData is None: + print("No 
clickData received") + return go.Figure() + + try: + node_name = clickData['points'][0]['text'] + except (KeyError, IndexError, TypeError) as e: + print(f"Error processing clickData: {e}") + return go.Figure() + + fig = go.Figure() + fig.add_trace( + nn_graph.metric_handler.plot_node_hist(node_name, stat=selected_metric) + ) + fig.update_layout(title=f"Metrics for {node_name} | {selected_metric}", + xaxis_title="Metric", + yaxis_title="Value") -# Example usage: -metrics = handler.layers() -nn_graph = NeuralNetworkGraph([pair for pair in out if pair[1] is not None]) -nn_graph.construct_graph() -nn_graph.plot_graph() -# %% + return fig + +if __name__ == '__main__': + app.run_server() \ No newline at end of file From e5305684cc48e0083a59c0fcd9a09a8cc2ebae8b Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 3 Jun 2024 12:19:31 +0100 Subject: [PATCH 07/64] Remove single-model stats for now. Bring MSE into all_metrics --- mergekit/metric_methods/MSE.py | 87 ----------------- mergekit/metric_methods/PCA_rank.py | 125 ------------------------- mergekit/metric_methods/__init__.py | 10 +- mergekit/metric_methods/all_metrics.py | 22 +++++ run_metrics.py | 3 +- 5 files changed, 25 insertions(+), 222 deletions(-) delete mode 100644 mergekit/metric_methods/MSE.py delete mode 100644 mergekit/metric_methods/PCA_rank.py diff --git a/mergekit/metric_methods/MSE.py b/mergekit/metric_methods/MSE.py deleted file mode 100644 index 6d36089e..00000000 --- a/mergekit/metric_methods/MSE.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
- -from typing import Any, Dict, List, Optional - -import torch - -from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference -from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod - -class MSEMetricTask(Task[torch.Tensor]): - gather_tensors: GatherTensors - weight_info: WeightInfo - - def uses_accelerator(self) -> bool: - return True - - def arguments(self) -> Dict[str, Task]: - return {"tensors": self.gather_tensors} - - def execute( - self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs - ) -> torch.Tensor: - keys = list(tensors.keys()) - tensors = [tensors[key] for key in keys] - - unique_shapes = set(t.shape for t in tensors) - if len(unique_shapes) != 1: - raise RuntimeError( - f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" - ) - if len(tensors) != 2: - raise RuntimeError(f"Expected 2 tensors, got {len(tensors)}") - - if 'mlp' not in self.weight_info.name: - return - - res = {} - # pairwise similarity of corresponding rows in weights matrix - - # Ensure the tensors have the same shape - assert tensors[0].shape == tensors[0].shape, "Tensors must have the same shape" - - # Compute the squared differences - squared_diff = (tensors[0] - tensors[1]) ** 2 - - - # Compute the mean of squared differences for each row - mse_per_neuron = torch.mean(squared_diff, dim=1) - - res['MSE_full'] = mse_per_neuron - res['MSE_mean'] = mse_per_neuron.mean() - res['MSE_std'] = mse_per_neuron.std() - - return res - - def group_label(self) -> Optional[str]: - return self.gather_tensors.group_label() - - -class MSEMetric(MetricMethod): - def make_task( - self, - *, - output_weight: WeightInfo, - tensors: GatherTensors, - **_kwargs, - ) -> Task: - return MSEMetricTask( - gather_tensors=tensors, - weight_info=output_weight, - ) diff --git a/mergekit/metric_methods/PCA_rank.py b/mergekit/metric_methods/PCA_rank.py deleted file mode 100644 index f268bed6..00000000 --- a/mergekit/metric_methods/PCA_rank.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -from typing import Any, Dict, List, Optional - -import torch - -from mergekit.architecture import WeightInfo -from mergekit.common import ImmutableMap, ModelReference -from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod -from mergekit.merge_methods.base import ConfigParameterDef -from mergekit.merge_methods.rectify_embed import rectify_embed_sizes - -import torch.nn.functional as F - -def pca_components_for_variance(X, variance_threshold=0.99, rescale=True): - """ - Compute the number of principal components required to explain - at least `variance_threshold` of the total variance in the dataset X using PyTorch. 
- - Args: - X (torch.Tensor): The data matrix. Rows are samples and columns are features. - variance_threshold (float): The fraction of total variance that we want to capture. - - Returns: - int: The number of principal components required to capture the specified variance threshold. - """ - # Standardize the data (mean 0 and variance 1) - X_mean = torch.mean(X, dim=0) - X_std = torch.std(X, dim=0, unbiased=False) - X = X - X_mean - - if rescale: - X = X / X_std - - # Compute the covariance matrix - covariance_matrix = torch.mm(X.T, X) / (X.shape[0] - 1) - - # Perform SVD on the covariance matrix - U, S, V = torch.svd(covariance_matrix) - - # Calculate explained variance ratios - explained_variance_ratio = S / torch.sum(S) - cumsum_variance = torch.cumsum(explained_variance_ratio, dim=0) - - # Determine the number of components needed to surpass the variance threshold - num_components = torch.where(cumsum_variance >= variance_threshold)[0][0] + 1 - - return num_components.item() - - -class PCA_RankTask(Task[torch.Tensor]): - gather_tensors: GatherTensors - tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] - normalize: bool - weight_info: WeightInfo - - def uses_accelerator(self) -> bool: - return True - - def arguments(self) -> Dict[str, Task]: - return {"tensors": self.gather_tensors} - - def execute( - self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs - ) -> torch.Tensor: - keys = list(tensors.keys()) - - tensors = [tensors[key] for key in keys] - - - unique_shapes = set(t.shape for t in tensors) - if len(unique_shapes) != 1: - raise RuntimeError( - f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" - ) - if len(tensors) != 1: - raise RuntimeError(f"Expected 1 tensors, got {len(tensors)}") - - if 'mlp' not in self.weight_info.name: - return - - res = {} - X = tensors[0] - - res['num_components_99'] = pca_components_for_variance(X, variance_threshold=0.99, rescale=True) - res['num_components_95'] = pca_components_for_variance(X, variance_threshold=0.95, rescale=True) - return res - - - def group_label(self) -> Optional[str]: - return self.gather_tensors.group_label() - - -class PCA_RankMetric(MetricMethod): - - def make_task( - self, - *, - output_weight: WeightInfo, - tensors: GatherTensors, - parameters: Dict[str, Any], - tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], - **_kwargs, - ) -> Task: - return PCA_RankTask( - gather_tensors=tensors, - tensor_parameters=tensor_parameters, - normalize=parameters["normalize"], - weight_info=output_weight, - ) diff --git a/mergekit/metric_methods/__init__.py b/mergekit/metric_methods/__init__.py index bd43c462..6027c8b6 100644 --- a/mergekit/metric_methods/__init__.py +++ b/mergekit/metric_methods/__init__.py @@ -14,17 +14,11 @@ # along with this program. If not, see http://www.gnu.org/licenses/. 
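For reference, the PCA_rank metric deleted above estimated how many principal components are needed to reach a target fraction of explained variance. A compact sketch of that idea, not the removed implementation, working from singular values of the centred matrix rather than an explicit covariance SVD:

import torch

# Illustrative only: count components whose cumulative explained-variance
# ratio first reaches `threshold`.
def effective_rank(X: torch.Tensor, threshold: float = 0.99) -> int:
    X = X - X.mean(dim=0)                        # centre the data
    var = torch.linalg.svdvals(X) ** 2           # proportional to explained variance
    ratio = torch.cumsum(var, dim=0) / var.sum()
    return int((ratio < threshold).sum().item()) + 1

print(effective_rank(torch.randn(256, 64), threshold=0.95))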
from mergekit.metric_methods.base import MetricMethod -from mergekit.metric_methods.PCA_rank import PCA_RankMetric -from mergekit.metric_methods.MSE import MSEMetric from mergekit.metric_methods.all_metrics import AllMetric def get(method: str) -> MetricMethod: - if method == "PCA_rank": - return PCA_RankMetric() - elif method == "MSE": - return MSEMetric() - elif method == "all": + if method == "all": return AllMetric() raise RuntimeError(f"Unimplemented metric method {method}") @@ -32,6 +26,4 @@ def get(method: str) -> MetricMethod: __all__ = [ "MetricMethod", "get", - "MSEMetric", - "PCA_RankMetric", ] diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index 7d20d347..4fc7f46a 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -96,6 +96,27 @@ def scale( 'scale_std': scale_diff.std() } +def mse( + tensors: List[torch.Tensor], **_kwargs +) -> torch.Tensor: + # Compute the squared differences + squared_diff = (tensors[0] - tensors[1]) ** 2 + + + # Compute the mean of squared differences for each row + mse_per_neuron = torch.mean(squared_diff, dim=1) + + hist_info = binning(mse_per_neuron, 100) + return { + 'mse_full': { + 'count': hist_info[0], + 'edges': hist_info[1], + 'widths': hist_info[2] + }, + 'mse_mean': mse_per_neuron.mean(), + 'mse_std': mse_per_neuron.std() + } + class AllMetricTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo @@ -120,6 +141,7 @@ def execute( res.update(cossim(tensors, angular_distance=self.angular_distance)) res.update(SMAPE(tensors)) res.update(scale(tensors)) + res.update(mse(tensors)) return res diff --git a/run_metrics.py b/run_metrics.py index 1e445e48..5ef8ee37 100644 --- a/run_metrics.py +++ b/run_metrics.py @@ -53,7 +53,8 @@ options=[ {'label': 'SMAPE', 'value': 'SMAPE_full'}, {'label': 'Cossim', 'value': 'cossim_full'}, - {'label': 'Scale', 'value': 'scale_full'} + {'label': 'Scale', 'value': 'scale_full'}, + {'label': 'MSE', 'value': 'mse_full'}, ], value='cossim_full', style={'font-family': 'Arial'} From 207e874e5a7c59930a5b403b5fca243bbe647082 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Fri, 7 Jun 2024 18:11:21 +0100 Subject: [PATCH 08/64] Introduce attention weights and restructure dashboard --- mergekit/metric_methods/all_metrics.py | 181 ++++++++++++++++++++----- mergekit/plot_tools/plot_tools.py | 166 ++++++++++++++++++++++- run_metrics.py | 87 +++--------- 3 files changed, 321 insertions(+), 113 deletions(-) diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index 4fc7f46a..d9848013 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -15,18 +15,19 @@ from typing import Any, Dict, List, Optional -import torch - from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference +from mergekit.common import ModelReference, ImmutableMap from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors +from mergekit.io.tasks import GatherTensors, LoadTensor from mergekit.metric_methods.base import MetricMethod from mergekit.metric_methods.base import ConfigParameterDef + +import torch import torch.nn.functional as F + import numpy as np -def binning(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: +def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: bin_counts, bin_edges = np.histogram(tensor.numpy(), bins=n_bins) bin_widths = np.diff(bin_edges) return 
bin_counts, bin_edges, bin_widths @@ -46,15 +47,15 @@ def SMAPE( denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) smape = torch.mean(torch.div(numerator, denominator), dim=1) - hist_info = binning(smape, 100) + hist_info = compute_histogram(smape, 100) return { 'SMAPE_full': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] }, - 'SMAPE_mean': smape.mean(), - 'SMAPE_std': smape.std() + 'SMAPE_mean': smape.mean().item(), + 'SMAPE_std': smape.std().item() } def cossim( @@ -64,15 +65,15 @@ def cossim( if _kwargs.get('angular_distance'): cossim = torch.acos(cossim.clamp(min=-1, max=1))/torch.pi - hist_info = binning(cossim, 100) + hist_info = compute_histogram(cossim, 100) return { 'cossim_full': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] }, - 'cossim_mean': cossim.mean(), - 'cossim_std': cossim.std() + 'cossim_mean': cossim.mean().item(), + 'cossim_std': cossim.std().item() } def scale( @@ -82,39 +83,68 @@ def scale( norm_0 = torch.norm(tensors[0], dim=1) norm_1 = torch.norm(tensors[1], dim=1) - scale_diff = torch.abs(norm_0 - norm_1) - scale_diff = scale_diff / ((norm_0 + norm_1) / 2) + scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) - hist_info = binning(scale_diff, 100) + hist_info = compute_histogram(scale_diff, 100) return { 'scale_full': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] }, - 'scale_mean': scale_diff.mean(), - 'scale_std': scale_diff.std() + 'scale_mean': scale_diff.mean().item(), + 'scale_std': scale_diff.std().item() } def mse( tensors: List[torch.Tensor], **_kwargs ) -> torch.Tensor: - # Compute the squared differences + squared_diff = (tensors[0] - tensors[1]) ** 2 - - - # Compute the mean of squared differences for each row mse_per_neuron = torch.mean(squared_diff, dim=1) - hist_info = binning(mse_per_neuron, 100) + hist_info = compute_histogram(mse_per_neuron, 100) return { 'mse_full': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] }, - 'mse_mean': mse_per_neuron.mean(), - 'mse_std': mse_per_neuron.std() + 'mse_mean': mse_per_neuron.mean().item(), + 'mse_std': mse_per_neuron.std().item() + } + +def restructure_tensor(input_tensor, num_columns): + """ + Restructure a tensor by splitting its columns. + + Args: + input_tensor (torch.Tensor): The input tensor to restructure. + num_columns (int): The number of columns for splitting. + + Returns: + torch.Tensor: The restructured tensor. + """ + rows, cols = input_tensor.shape + new_cols = cols // num_columns + reshaped_tensor = input_tensor.view(rows, num_columns, new_cols) + restructured_tensor = reshaped_tensor.permute(1, 0, 2) + + return restructured_tensor + +def compare_attn_head_weights(k_proj, q_proj, v_proj, o_proj, num_heads, **_kwargs): + models = list(q_proj.keys()) + q_proj_0 = restructure_tensor(q_proj[models[0]], num_heads) + q_proj_1 = restructure_tensor(q_proj[models[1]], num_heads) + + # Now the first dimension is the head index, so can be compared pairwise or even cross compared within/between models. 
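The comment above is the crux of the attention comparison: once the first dimension indexes heads, any pairwise statistic can be taken head against head. A minimal sketch under assumed shapes (the head layout and names here are illustrative, and broadcasting stands in for the explicit double loop used in the patch):

import torch

# Assumed shapes for illustration: one projection weight per model, reshaped so
# that dim 0 indexes heads, then a head-vs-head MSE heatmap via broadcasting.
num_heads, head_dim, hidden = 4, 8, 32
q0 = torch.randn(num_heads * head_dim, hidden)
q1 = torch.randn(num_heads * head_dim, hidden)

h0 = q0.reshape(num_heads, -1)                                    # one flat vector per head
h1 = q1.reshape(num_heads, -1)
heatmap = ((h0[:, None, :] - h1[None, :, :]) ** 2).mean(dim=-1)   # (num_heads, num_heads)

The diagonal compares corresponding heads across models; the off-diagonal entries give the cross comparisons the comment mentions.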
+ heatmap = np.zeros((num_heads, num_heads)) + for i in range(num_heads): + for j in range(num_heads): + heatmap[i, j] = ((q_proj_0[i].flatten() - q_proj_1[j].flatten()) ** 2).mean().item() + + return { + 'MSE Attn Heatmap': heatmap, } class AllMetricTask(Task[torch.Tensor]): @@ -133,23 +163,81 @@ def execute( ) -> torch.Tensor: tensors = list(tensors.values()) validate_tensors(tensors, self.weight_info, expected_tensors=2) - if 'mlp' not in self.weight_info.name: - return - res = {} + if 'mlp' in self.weight_info.name: - res.update(cossim(tensors, angular_distance=self.angular_distance)) - res.update(SMAPE(tensors)) - res.update(scale(tensors)) - res.update(mse(tensors)) + res.update(cossim(tensors, angular_distance=self.angular_distance)) + res.update(SMAPE(tensors)) + res.update(scale(tensors)) + res.update(mse(tensors)) return res def group_label(self) -> Optional[str]: return self.gather_tensors.group_label() +class AttnTask(Task[torch.Tensor]): + weights: Dict[str, GatherTensors] + weight_infos: Dict[str, WeightInfo] + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + + return self.weights + + def execute( + self, k_proj, v_proj, q_proj, o_proj, **_kwargs + ) -> torch.Tensor: + # Add metrics for attention weights + res = {} + res.update(compare_attn_head_weights(k_proj, q_proj, v_proj, o_proj, num_heads=32)) # 32 is a placeholder + + return res + + def group_label(self) -> Optional[str]: + # Use max of the group labels + return max([gather_tensor.group_label() for gather_tensor in list(self.weights.values())]) # Check this (X) + + def __hash__(self): + return hash((tuple(self.weight_infos),)) + + def __eq__(self, other): + if not isinstance(other, AttnTask): + return False + return self.weight_infos == other.weight_infos + +class blankTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + weight_info: WeightInfo + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, **_kwargs + ) -> torch.Tensor: + + return + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + class AllMetric(MetricMethod): + attn_weight_tensors: Optional[list] = [] + attn_weight_infos: Optional[list] = [] + + attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {} + attn_info_dict: Optional[Dict[str, WeightInfo]] = {} + + + attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] + block_count: Optional[int] = 0 def parameters(self) -> List[ConfigParameterDef]: return [ ConfigParameterDef(name="angular_distance", required=False, default_value=False), @@ -162,10 +250,33 @@ def make_task( tensors: GatherTensors, **_kwargs, ) -> Task: - return AllMetricTask( - gather_tensors=tensors, - weight_info=output_weight, - angular_distance=parameters["angular_distance"] - ) + + if 'self_attn' in output_weight.name: + for part in self.attn_parts: # also check only one key + if part in output_weight.name: + self.attn_weight_dict[part] = tensors + self.attn_info_dict[part] = output_weight + + if set(list(self.attn_weight_dict.keys())) == set(self.attn_parts): + weights = self.attn_weight_dict + infos = self.attn_info_dict + self.attn_weight_dict = {} + self.attn_info_dict = {} + weight_info = WeightInfo( + name=f"Attention Block {self.block_count}", + force_dtype=None, + optional=False, + aliases=None, + ) + self.block_count += 1 + return AttnTask(weights=weights, weight_infos=infos, 
weight_info=weight_info) + if 'mlp' in output_weight.name: + return AllMetricTask( + gather_tensors=tensors, + weight_info=output_weight, + angular_distance=parameters["angular_distance"] + ) + else: + return blankTask(gather_tensors=tensors, weight_info=output_weight) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index bb65c3b9..37c099de 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -4,8 +4,11 @@ import networkx as nx import plotly.graph_objects as go import matplotlib.pyplot as plt +import dash +from dash import dcc, html +from dash.dependencies import Input, Output, State -class MetricsHandler(): +class MetricsHandler: """ Object to handle metrics output. Allows for easy plotting of metrics by layer and across layers. @@ -56,14 +59,14 @@ def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): ax_kwargs = ['ylabel', 'title', 'ylim', 'xticklabels'] plot_kwargs = {k: v for k, v in kwargs.items() if k not in ax_kwargs} - self._plot(ax, stat, plot_kwargs) + self._line_plot(ax, stat, plot_kwargs) self._set_plot_attributes(ax, stat, ax_kwargs, **kwargs) if save_to: plt.savefig(save_to) plt.show() plt.close() - def _plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): + def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): """ Plot the stat values with optional error bars. @@ -87,6 +90,23 @@ def _plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): else: ax.plot(stat_values, **plot_kwargs) + def heatmap_plot(self, layer_name:str, stat:str): + """ + Plot the stat values as a heatmap. + + Args: + layer_name (str): The name of the layer. + stat (str): The name of the stat to plot. + Returns: + go.Heatmap: Plotly Heatmap object. + """ + heatmap = self.all_stats[layer_name]['metric'][stat] + + return go.Heatmap( + z=heatmap, + colorscale='RdBu' + ) + def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): """ Set the attributes of the plot. @@ -110,7 +130,6 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): def plot_node_hist(self, layer_name: str, stat: str): bin_counts, bin_edges, bin_widths = self.all_stats[layer_name]['metric'][stat].values() - # Create a bar chart using Plotly return go.Bar( x=bin_edges[:-1], y=bin_counts, @@ -219,7 +238,7 @@ def plot_graph(self, save_to: str = None): hover_text.append(hover) first_metric = metrics_to_plot[0] - node_colors = [value.item() for value in node_values[first_metric]] + node_colors = [value for value in node_values[first_metric]] node_trace = go.Scatter( x=node_x, y=node_y, @@ -232,8 +251,8 @@ def plot_graph(self, save_to: str = None): showscale=True, colorscale='Viridis', color=node_colors, - cmin=min(node_values[first_metric]).item(), - cmax=max(node_values[first_metric]).item(), + cmin=min(node_values[first_metric]), + cmax=max(node_values[first_metric]), size=10, colorbar=dict( thickness=15, @@ -256,3 +275,136 @@ def plot_graph(self, save_to: str = None): return fig +def create_app(nn_graph): + """ + Creates and configures a Dash app to visualize metrics from a neural network graph. + + Args: + nn_graph (ModelGraph): An instance of the neural network graph to be visualized. + + Returns: + app (dash.Dash): Configured Dash app ready to be run. 
+ """ + # Initialize the Dash app + app = dash.Dash(__name__) + + # Define the layout of the app + app.layout = html.Div([ + dcc.Graph( + id='graph', + figure=nn_graph.plot_graph(), # Initial plot of the neural network graph + ), + dcc.Dropdown( + id='metric-dropdown', + options=[], # Options will be set dynamically based on the node type + style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} + ), + dcc.Graph(id='node-details'), # Placeholder for detailed metrics of a selected node + dcc.Store(id='node-type-store') # Store to keep track of the node type + ]) + + # Callback to update the node type when a node in the graph is clicked + @app.callback( + Output('node-type-store', 'data'), + [Input('graph', 'clickData')] + ) + def update_node_type(clickData): + """ + Updates the node type based on the clicked node in the graph. + + Args: + clickData (dict): Data about the clicked node. + + Returns: + str: 'histogram' or 'heatmap' depending on the node type, or an empty string if no valid node is clicked. + """ + if clickData is None: + return '' + + try: + node_name = clickData['points'][0]['text'] + return 'histogram' if 'Attention Block' not in node_name else 'heatmap' + except (KeyError, IndexError, TypeError) as e: + print(f"Error processing clickData: {e}") + return '' + + # Callback to update the dropdown options based on the node type + @app.callback( + Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), + [Input('node-type-store', 'data')] + ) + def update_dropdown_options(node_type): + """ + Updates the options in the dropdown menu based on the node type. + + Args: + node_type (str): The type of the node ('histogram' or 'heatmap'). + + Returns: + list: List of dropdown options. + str: Default value for the dropdown. + """ + if node_type == 'histogram': + return [ + {'label': 'SMAPE', 'value': 'SMAPE_full'}, + {'label': 'Cossim', 'value': 'cossim_full'}, + {'label': 'Scale', 'value': 'scale_full'}, + {'label': 'MSE', 'value': 'mse_full'}, + ], 'cossim_full' + elif node_type == 'heatmap': + return [ + {'label': 'MSE Attn Heatmap', 'value': 'MSE Attn Heatmap'} + ], 'MSE Attn Heatmap' + else: + return [], None + + # Callback to update the node details graph based on selected metric and node type + @app.callback( + Output('node-details', 'figure'), + [Input('graph', 'clickData'), + Input('metric-dropdown', 'value')], + [State('node-type-store', 'data')] + ) + def display_node_data(clickData, selected_metric, node_type): + """ + Updates the node details graph based on the clicked node and selected metric. + + Args: + clickData (dict): Data about the clicked node. + selected_metric (str): The selected metric from the dropdown. + node_type (str): The type of the node ('histogram' or 'heatmap'). + + Returns: + go.Figure: The updated figure to display. 
+ """ + if clickData is None: + print("No clickData received") + return go.Figure() + + try: + node_name = clickData['points'][0]['text'] + except (KeyError, IndexError, TypeError) as e: + print(f"Error processing clickData: {e}") + return go.Figure() + + fig = go.Figure() + if node_type == 'histogram': + trace = nn_graph.metric_handler.plot_node_hist(node_name, stat=selected_metric) + fig.add_trace(trace) + fig.update_layout( + title=f"Metrics for {node_name} | {selected_metric}", + xaxis_title="Metric", + yaxis_title="Value" + ) + elif node_type == 'heatmap': + trace = nn_graph.metric_handler.heatmap_plot(layer_name=node_name, stat='MSE Attn Heatmap') + fig.add_trace(trace) + fig.update_layout( + title=f"{node_name} | {selected_metric}", + xaxis_title="Model 1 Head", + yaxis_title="Model 0 Head" + ) + + return fig + + return app diff --git a/run_metrics.py b/run_metrics.py index 5ef8ee37..fba0badd 100644 --- a/run_metrics.py +++ b/run_metrics.py @@ -1,6 +1,5 @@ #%% OUTPUT_PATH = "./merged" # folder to store the result in -LORA_MERGE_CACHE = "/tmp" # change if you want to keep these for some reason CONFIG_YML = "./examples/linear_small.yml" # merge configuration file COPY_TOKENIZER = True # you want a tokenizer? yeah, that's what i thought LAZY_UNPICKLE = False # experimental low-memory model loader @@ -12,80 +11,26 @@ from mergekit.config import MergeConfiguration from mergekit.merge import MergeOptions from mergekit.measure import run_measure -from mergekit.plot_tools.plot_tools import ModelGraph -import dash -from dash import dcc, html -from dash.dependencies import Input, Output -import plotly.graph_objects as go - +from mergekit.plot_tools.plot_tools import ModelGraph, create_app +#%% with open(CONFIG_YML, "r", encoding="utf-8") as fp: - merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) - -out = run_measure( - merge_config, - out_path=OUTPUT_PATH, - options=MergeOptions( - lora_merge_cache=LORA_MERGE_CACHE, - cuda=torch.cuda.is_available(), - copy_tokenizer=COPY_TOKENIZER, - lazy_unpickle=LAZY_UNPICKLE, - low_cpu_memory=LOW_CPU_MEMORY, - ), -) - -nn_graph = ModelGraph([pair for pair in out if pair[1] is not None]) -nn_graph.construct_graph() + metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) -# Initialize the Dash app -app = dash.Dash(__name__) +if __name__ == '__main__': -app.layout = html.Div([ - html.Label('Model Architecture Graph | Hover for node stats | Click for node plots', style={'font-family': 'Arial'}), - dcc.Graph( - id='graph', - figure=nn_graph.plot_graph(), - ), - dcc.Graph(id='node-details'), - html.Label('Select Metric:', style={'font-family': 'Arial'}), - # Add a dropdown menu to select the metric - dcc.Dropdown( - id='metric-dropdown', - options=[ - {'label': 'SMAPE', 'value': 'SMAPE_full'}, - {'label': 'Cossim', 'value': 'cossim_full'}, - {'label': 'Scale', 'value': 'scale_full'}, - {'label': 'MSE', 'value': 'mse_full'}, - ], - value='cossim_full', - style={'font-family': 'Arial'} + out = run_measure( + metric_config, + out_path=OUTPUT_PATH, + options=MergeOptions( + cuda=torch.cuda.is_available(), + copy_tokenizer=COPY_TOKENIZER, + lazy_unpickle=LAZY_UNPICKLE, + low_cpu_memory=LOW_CPU_MEMORY, + ), ) -]) - -@app.callback( - Output('node-details', 'figure'), - [Input('graph', 'clickData'), - Input('metric-dropdown', 'value')]) -def display_node_data(clickData, selected_metric): - if clickData is None: - print("No clickData received") - return go.Figure() - - try: - node_name = clickData['points'][0]['text'] - except 
(KeyError, IndexError, TypeError) as e: - print(f"Error processing clickData: {e}") - return go.Figure() - - fig = go.Figure() - fig.add_trace( - nn_graph.metric_handler.plot_node_hist(node_name, stat=selected_metric) - ) - fig.update_layout(title=f"Metrics for {node_name} | {selected_metric}", - xaxis_title="Metric", - yaxis_title="Value") + nn_graph = ModelGraph([pair for pair in out if pair[1] is not None]) + nn_graph.construct_graph() - return fig - -if __name__ == '__main__': + app = create_app(nn_graph=nn_graph) app.run_server() \ No newline at end of file From b30175eed51609986b1f3bba7302106b7de5544d Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 10 Jun 2024 18:05:26 +0100 Subject: [PATCH 09/64] refine implementation of attention metrics, add line plots to dashboard --- mergekit/architecture.py | 7 + mergekit/metric_methods/all_metrics.py | 119 ++++++++++++----- mergekit/plot_tools/plot_tools.py | 169 ++++++++++++------------- 3 files changed, 172 insertions(+), 123 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 653f1ac3..66270127 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -43,6 +43,10 @@ class WeightInfo(BaseModel, frozen=True): List of alternative names for the weight, if applicable. force_dtype (Optional[str]): Mandatory dtype for the weight, if applicable. + GQA_groups (Optional[int]): + Number of groups for GQA-style weight sharing, if applicable. + num_heads (Optional[int]): + Number of heads for multihead attention, if applicable. """ name: str @@ -53,6 +57,9 @@ class WeightInfo(BaseModel, frozen=True): aliases: Optional[Tuple[str, ...]] = None force_dtype: Optional[str] = None + GQA_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA + num_heads: Optional[int] = None + class ProceduralSpaceInfo(BaseModel, frozen=True): """Defines a procedural space computed from one or more other spaces. 
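The GQA_groups and num_heads fields added here feed the head ungrouping in all_metrics.py: grouped K/V projections have to be expanded so that every query head lines up with a K/V block of its own before heads are compared. A small sketch of that expansion under assumed shapes (names here are illustrative, not the patch's):

import torch

# Illustrative GQA expansion: each K/V head is shared by num_heads // num_kv_heads
# query heads, so its block is repeated that many times along the head dimension.
num_heads, num_kv_heads, head_dim, hidden = 8, 2, 4, 16
k_proj = torch.randn(num_kv_heads * head_dim, hidden)

per_kv_head = k_proj.view(num_kv_heads, head_dim, hidden)
ungrouped = per_kv_head.repeat_interleave(num_heads // num_kv_heads, dim=0)
print(ungrouped.shape)  # torch.Size([8, 4, 16]), one K block per query head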
diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index d9848013..ff317985 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -49,7 +49,7 @@ def SMAPE( hist_info = compute_histogram(smape, 100) return { - 'SMAPE_full': { + 'SMAPE Histogram': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] @@ -67,7 +67,7 @@ def cossim( hist_info = compute_histogram(cossim, 100) return { - 'cossim_full': { + 'cossim Histogram': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] @@ -87,7 +87,7 @@ def scale( hist_info = compute_histogram(scale_diff, 100) return { - 'scale_full': { + 'scale Histogram': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] @@ -97,55 +97,86 @@ def scale( } def mse( - tensors: List[torch.Tensor], **_kwargs + tensors: List[torch.Tensor], heatmap=False, **_kwargs ) -> torch.Tensor: + + res = {} + + if heatmap: + num_heads = tensors[0].shape[0] + heatmap = np.zeros((num_heads, num_heads)) + for i in range(num_heads): + for j in range(num_heads): + heatmap[i, j] = ((tensors[0][i] - tensors[1][j]) ** 2).mean().item() + res['MSE Attn Heatmap'] = heatmap squared_diff = (tensors[0] - tensors[1]) ** 2 mse_per_neuron = torch.mean(squared_diff, dim=1) hist_info = compute_histogram(mse_per_neuron, 100) - return { - 'mse_full': { + res.update({ + 'mse Histogram': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] }, 'mse_mean': mse_per_neuron.mean().item(), 'mse_std': mse_per_neuron.std().item() - } + }) + return res -def restructure_tensor(input_tensor, num_columns): +def ungroup_tensor(input_tensor, GQA_groups): """ - Restructure a tensor by splitting its columns. + Ungroup a tensor by repeating its columns. + + Args: + input_tensor (torch.Tensor): The input tensor to ungroup. + GQA_groups (int): The number of GQA groups. + + Returns: + torch.Tensor: The ungrouped tensor. + """ + rows, cols = input_tensor.shape + new_rows = rows * GQA_groups + ungrouped_tensor = torch.zeros(new_rows, cols) + + for i in range(GQA_groups): + ungrouped_tensor[i*rows:(i+1)*rows] = input_tensor[i].expand(rows, -1) + + return ungrouped_tensor + +def restructure_tensor(input_tensor, num_rows): + """ + Restructure a tensor by splitting its rows. Args: input_tensor (torch.Tensor): The input tensor to restructure. - num_columns (int): The number of columns for splitting. + num_rows (int): The number of rows for splitting. Returns: torch.Tensor: The restructured tensor. """ rows, cols = input_tensor.shape - new_cols = cols // num_columns - reshaped_tensor = input_tensor.view(rows, num_columns, new_cols) + new_rows = rows // num_rows + reshaped_tensor = input_tensor.view(new_rows, num_rows, cols) restructured_tensor = reshaped_tensor.permute(1, 0, 2) return restructured_tensor -def compare_attn_head_weights(k_proj, q_proj, v_proj, o_proj, num_heads, **_kwargs): - models = list(q_proj.keys()) - q_proj_0 = restructure_tensor(q_proj[models[0]], num_heads) - q_proj_1 = restructure_tensor(q_proj[models[1]], num_heads) +def group_attn_head_weights(k_proj, q_proj, v_proj, o_proj, weight_info): - # Now the first dimension is the head index, so can be compared pairwise or even cross compared within/between models. 
- heatmap = np.zeros((num_heads, num_heads)) - for i in range(num_heads): - for j in range(num_heads): - heatmap[i, j] = ((q_proj_0[i].flatten() - q_proj_1[j].flatten()) ** 2).mean().item() - - return { - 'MSE Attn Heatmap': heatmap, - } + num_heads = weight_info.num_heads + GQA_groups = weight_info.GQA_groups + + k_proj = ungroup_tensor(k_proj, GQA_groups) + v_proj = ungroup_tensor(v_proj, GQA_groups) + + k_proj = restructure_tensor(k_proj, num_heads) + v_proj = restructure_tensor(v_proj, num_heads) + q_proj = restructure_tensor(q_proj, num_heads) + o_proj = restructure_tensor(o_proj.T, num_heads) + + return k_proj, v_proj, q_proj, o_proj class AllMetricTask(Task[torch.Tensor]): gather_tensors: GatherTensors @@ -193,7 +224,25 @@ def execute( ) -> torch.Tensor: # Add metrics for attention weights res = {} - res.update(compare_attn_head_weights(k_proj, q_proj, v_proj, o_proj, num_heads=32)) # 32 is a placeholder + models = list(q_proj.keys()) + + k_proj_0, v_proj_0, q_proj_0, o_proj_0 = group_attn_head_weights(k_proj[models[0]], q_proj[models[0]], v_proj[models[0]], o_proj[models[0]], self.weight_info) + k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(k_proj[models[1]], q_proj[models[1]], v_proj[models[1]], o_proj[models[1]], self.weight_info) + + # Metrics for K, V, Q, O projections + + model_0_heads = torch.cat([k_proj_0, v_proj_0, q_proj_0, o_proj_0], dim=1) + model_1_heads = torch.cat([k_proj_1, v_proj_1, q_proj_1, o_proj_1], dim=1) + + # Metrics for heads + res.update(mse([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + heatmap=True)) + res.update(cossim([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + angular_distance=True)) + res.update(scale([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)])) return res @@ -202,14 +251,14 @@ def group_label(self) -> Optional[str]: return max([gather_tensor.group_label() for gather_tensor in list(self.weights.values())]) # Check this (X) def __hash__(self): - return hash((tuple(self.weight_infos),)) + return hash(self.weight_info) def __eq__(self, other): if not isinstance(other, AttnTask): return False - return self.weight_infos == other.weight_infos + return self.weight_info == other.weight_info -class blankTask(Task[torch.Tensor]): +class DummyTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo def uses_accelerator(self) -> bool: @@ -252,21 +301,23 @@ def make_task( ) -> Task: if 'self_attn' in output_weight.name: + # collect all attention weights for part in self.attn_parts: # also check only one key if part in output_weight.name: self.attn_weight_dict[part] = tensors self.attn_info_dict[part] = output_weight + # if all attention weights are collected, create attention task if set(list(self.attn_weight_dict.keys())) == set(self.attn_parts): - weights = self.attn_weight_dict - infos = self.attn_info_dict - self.attn_weight_dict = {} - self.attn_info_dict = {} + weights, infos = self.attn_weight_dict, self.attn_info_dict + self.attn_weight_dict, self.attn_info_dict = {}, {} weight_info = WeightInfo( name=f"Attention Block {self.block_count}", force_dtype=None, optional=False, aliases=None, + GQA_groups=4, + num_heads=32 ) self.block_count += 1 return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) @@ -277,6 +328,6 @@ def make_task( angular_distance=parameters["angular_distance"] ) else: - return blankTask(gather_tensors=tensors, 
weight_info=output_weight) + return DummyTask(gather_tensors=tensors, weight_info=output_weight) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 37c099de..9acb735d 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -65,6 +65,23 @@ def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): plt.savefig(save_to) plt.show() plt.close() + + def plotly_line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): + fig = go.Figure() + + y = [self.all_stats[layer]['metric'][stat] for layer in self.layer_names] + if f'{stat}'.replace('mean', 'std') in self.stat_names: + std_stat = f'{stat}'.replace('mean', 'std') + std_values = [self.all_stats[layer]['metric'].get(std_stat) for layer in self.layer_names] + + return go.Scatter( + x=self.layer_names, + y=y, + mode='lines+markers', + name='Line Plot' + ) + + def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): """ @@ -129,7 +146,7 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): def plot_node_hist(self, layer_name: str, stat: str): - bin_counts, bin_edges, bin_widths = self.all_stats[layer_name]['metric'][stat].values() + bin_counts, bin_edges, bin_widths = self.stats_at_layer(layer_name)[stat].values() return go.Bar( x=bin_edges[:-1], y=bin_counts, @@ -276,109 +293,68 @@ def plot_graph(self, save_to: str = None): def create_app(nn_graph): - """ - Creates and configures a Dash app to visualize metrics from a neural network graph. - - Args: - nn_graph (ModelGraph): An instance of the neural network graph to be visualized. - - Returns: - app (dash.Dash): Configured Dash app ready to be run. - """ - # Initialize the Dash app - app = dash.Dash(__name__) + app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css']) + # app = dash.Dash(__name__, external_stylesheets=['https://bootswatch.com/4/darkly/bootstrap.min.css']) - # Define the layout of the app app.layout = html.Div([ - dcc.Graph( - id='graph', - figure=nn_graph.plot_graph(), # Initial plot of the neural network graph - ), - dcc.Dropdown( - id='metric-dropdown', - options=[], # Options will be set dynamically based on the node type - style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} - ), - dcc.Graph(id='node-details'), # Placeholder for detailed metrics of a selected node - dcc.Store(id='node-type-store') # Store to keep track of the node type + html.Div([ + html.H1('Neural Network Similarity Vislualisation', style={'textAlign': 'center'}), + dcc.Graph( + id='graph', + figure=nn_graph.plot_graph(), + style={'width': '100%', 'height': '50vh'} + ), + ], className='container-fluid'), + + html.Div([ + html.H3('Node Metrics', style={'textAlign': 'center'}), + dcc.Dropdown( + id='metric-dropdown', + options=[], + style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} + ), + dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '50vh'}), + ], className='container-fluid'), + + html.Div([ + html.H3('Metrics Across Layers', style={'textAlign': 'center'}), + dcc.Dropdown( + id='line-plot-dropdown', + options=[{'label': metric.replace('_', ' ').title(), 'value': metric} for metric in nn_graph.metric_handler.stat_names if 'mean' in metric], + value='cossim_mean', + style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} + ), + dcc.Graph(id='line-plot', style={'width': '100%', 'height': 
'100vh'}), + ], className='container-fluid'), ]) - # Callback to update the node type when a node in the graph is clicked @app.callback( - Output('node-type-store', 'data'), + Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), [Input('graph', 'clickData')] ) - def update_node_type(clickData): - """ - Updates the node type based on the clicked node in the graph. - - Args: - clickData (dict): Data about the clicked node. - - Returns: - str: 'histogram' or 'heatmap' depending on the node type, or an empty string if no valid node is clicked. - """ + def update_dropdown_options(clickData): if clickData is None: - return '' + return [] try: node_name = clickData['points'][0]['text'] - return 'histogram' if 'Attention Block' not in node_name else 'heatmap' + options = list(nn_graph.metric_handler.stats_at_layer(node_name).keys()) + options = [option for option in options if 'std' not in option] + options = [ + {'label': option.replace('_', ' ').capitalize(), 'value': option} for option in options if 'mean' not in option + ] + return options, options[0]['value'] + except (KeyError, IndexError, TypeError) as e: print(f"Error processing clickData: {e}") - return '' + return [] - # Callback to update the dropdown options based on the node type @app.callback( - Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), - [Input('node-type-store', 'data')] + Output('node-details-plot', 'figure'), + [Input('graph', 'clickData'), Input('metric-dropdown', 'value')], ) - def update_dropdown_options(node_type): - """ - Updates the options in the dropdown menu based on the node type. - - Args: - node_type (str): The type of the node ('histogram' or 'heatmap'). - - Returns: - list: List of dropdown options. - str: Default value for the dropdown. - """ - if node_type == 'histogram': - return [ - {'label': 'SMAPE', 'value': 'SMAPE_full'}, - {'label': 'Cossim', 'value': 'cossim_full'}, - {'label': 'Scale', 'value': 'scale_full'}, - {'label': 'MSE', 'value': 'mse_full'}, - ], 'cossim_full' - elif node_type == 'heatmap': - return [ - {'label': 'MSE Attn Heatmap', 'value': 'MSE Attn Heatmap'} - ], 'MSE Attn Heatmap' - else: - return [], None - - # Callback to update the node details graph based on selected metric and node type - @app.callback( - Output('node-details', 'figure'), - [Input('graph', 'clickData'), - Input('metric-dropdown', 'value')], - [State('node-type-store', 'data')] - ) - def display_node_data(clickData, selected_metric, node_type): - """ - Updates the node details graph based on the clicked node and selected metric. - - Args: - clickData (dict): Data about the clicked node. - selected_metric (str): The selected metric from the dropdown. - node_type (str): The type of the node ('histogram' or 'heatmap'). - - Returns: - go.Figure: The updated figure to display. 
- """ + def display_node_data(clickData, selected_metric): if clickData is None: - print("No clickData received") return go.Figure() try: @@ -388,7 +364,7 @@ def display_node_data(clickData, selected_metric, node_type): return go.Figure() fig = go.Figure() - if node_type == 'histogram': + if 'histogram' in selected_metric or 'Histogram' in selected_metric: trace = nn_graph.metric_handler.plot_node_hist(node_name, stat=selected_metric) fig.add_trace(trace) fig.update_layout( @@ -396,7 +372,7 @@ def display_node_data(clickData, selected_metric, node_type): xaxis_title="Metric", yaxis_title="Value" ) - elif node_type == 'heatmap': + elif 'heatmap' in selected_metric or 'Heatmap' in selected_metric: trace = nn_graph.metric_handler.heatmap_plot(layer_name=node_name, stat='MSE Attn Heatmap') fig.add_trace(trace) fig.update_layout( @@ -407,4 +383,19 @@ def display_node_data(clickData, selected_metric, node_type): return fig + @app.callback( + Output('line-plot', 'figure'), + [Input('line-plot-dropdown', 'value')] + ) + def update_line_plot(selected_metric): + fig = go.Figure() + stat_values = [nn_graph.metric_handler.stats_at_layer(layer)[selected_metric] for layer in nn_graph.metric_handler.layer_names] + fig.add_trace(go.Scatter(x=nn_graph.metric_handler.layer_names, y=stat_values, mode='lines+markers')) + fig.update_layout( + title=f"{selected_metric.replace('_', ' ').capitalize()} Across Layers", + xaxis_title="Layer", + yaxis_title=selected_metric.replace('_', ' ').capitalize() + ) + return fig + return app From dfc0603ff44bd02d70d5414d0727928dde93ec9f Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Tue, 11 Jun 2024 11:57:19 +0100 Subject: [PATCH 10/64] More restructuring, more seamless integration of attention and mlp layers. heatmaps for other metrics. --- mergekit/graph.py | 1 + mergekit/measure.py | 2 +- mergekit/metric_methods/all_metrics.py | 204 ++++++++++++++----------- mergekit/plot_tools/plot_tools.py | 47 +++--- run_metrics.py | 6 +- 5 files changed, 150 insertions(+), 110 deletions(-) diff --git a/mergekit/graph.py b/mergekit/graph.py index c81cb85b..54db5275 100644 --- a/mergekit/graph.py +++ b/mergekit/graph.py @@ -37,6 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True): Abstract base class representing a task in a computational graph. This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy. + Note that PyDantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after. Attributes: Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution. 
diff --git a/mergekit/measure.py b/mergekit/measure.py index 95e78028..0c18a371 100644 --- a/mergekit/measure.py +++ b/mergekit/measure.py @@ -18,7 +18,7 @@ import tqdm import transformers -from mergekit.architecture import ArchitectureInfo, get_architecture_info +from mergekit.architecture import get_architecture_info from mergekit.config import MergeConfiguration from mergekit.graph import Executor from mergekit.io.tasks import LoaderCache diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index ff317985..c0dacaf0 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -16,33 +16,95 @@ from typing import Any, Dict, List, Optional from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference, ImmutableMap +from mergekit.common import ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors, LoadTensor +from mergekit.io.tasks import GatherTensors from mergekit.metric_methods.base import MetricMethod -from mergekit.metric_methods.base import ConfigParameterDef import torch import torch.nn.functional as F import numpy as np -def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: - bin_counts, bin_edges = np.histogram(tensor.numpy(), bins=n_bins) - bin_widths = np.diff(bin_edges) - return bin_counts, bin_edges, bin_widths - def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): + """Validate tensor shapes and count.""" unique_shapes = set(t.shape for t in tensors) if len(unique_shapes) != 1: raise RuntimeError(f"Tensor size mismatch for {weight_info.name}, sizes: {list(unique_shapes)}") if expected_tensors: if len(tensors) != expected_tensors: raise RuntimeError(f"Expected {expected_tensors} tensors, got {len(tensors)}") + +def ungroup_tensor(input_tensor: torch.Tensor, GQA_groups: int) -> torch.Tensor: + """ + Ungroup a grouped tensor by repeating its rows. + """ + rows, cols = input_tensor.shape + new_rows = rows * GQA_groups + ungrouped_tensor = torch.zeros(new_rows, cols) + + for i in range(GQA_groups): + ungrouped_tensor[i*rows:(i+1)*rows] = input_tensor[i].expand(rows, -1) + + return ungrouped_tensor + +def restructure_tensor(input_tensor: torch.Tensor, num_rows: int) -> torch.Tensor: + """ + Restructure a tensor by splitting its rows and permuting the dimensions. + + This is used so that the attention weights can be grouped by head in the first dimension. 
+ """ + rows, cols = input_tensor.shape + new_rows = rows // num_rows + reshaped_tensor = input_tensor.view(new_rows, num_rows, cols) + restructured_tensor = reshaped_tensor.permute(1, 0, 2) + + return restructured_tensor + +def group_attn_head_weights(k_proj: torch.Tensor, + q_proj: torch.Tensor, + v_proj: torch.Tensor, + o_proj: torch.Tensor, + weight_info: WeightInfo) -> tuple[torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor]: + + num_heads = weight_info.num_heads + GQA_groups = weight_info.GQA_groups + + k_proj = ungroup_tensor(k_proj, GQA_groups) + v_proj = ungroup_tensor(v_proj, GQA_groups) + + k_proj = restructure_tensor(k_proj, num_heads) + v_proj = restructure_tensor(v_proj, num_heads) + q_proj = restructure_tensor(q_proj, num_heads) + o_proj = restructure_tensor(o_proj.T, num_heads) # Output weights are split into heads by rows, not columns + + return k_proj, v_proj, q_proj, o_proj + +def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: + bin_counts, bin_edges = np.histogram(tensor.numpy(), bins=n_bins) + bin_widths = np.diff(bin_edges) + return bin_counts, bin_edges, bin_widths + +def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + # Normalize the rows of both matrices + A_norm = A / A.norm(dim=1, keepdim=True) + B_norm = B / B.norm(dim=1, keepdim=True) + + # Compute the cosine similarity matrix + similarity_matrix = torch.mm(A_norm, B_norm.t()) + return similarity_matrix + +# Metric functions + def SMAPE( tensors: List[torch.Tensor], **_kwargs ) -> Dict[str, Any]: + """Symmetric Mean Absolute Percentage Error (SMAPE).""" + numerator = torch.abs(tensors[0] - tensors[1]) denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) smape = torch.mean(torch.div(numerator, denominator), dim=1) @@ -59,14 +121,21 @@ def SMAPE( } def cossim( - tensors: List[torch.Tensor], **_kwargs + tensors: List[torch.Tensor], return_heatmap=False, **_kwargs ) -> torch.Tensor: + """Cosine similarity""" cossim = F.cosine_similarity(tensors[0], tensors[1], dim=1) - if _kwargs.get('angular_distance'): - cossim = torch.acos(cossim.clamp(min=-1, max=1))/torch.pi + + res = {} + + if return_heatmap: + res.update({'Cossim Heatmap': cossim_heatmap(tensors[0], tensors[1])}) + + assert torch.isclose(cossim, cossim, atol=1e-6).all(), "NaNs in cosine similarity" + assert torch.isclose(cossim, cossim_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-4).all(), "Diagonal elements of cosine similarity matrix do not match" hist_info = compute_histogram(cossim, 100) - return { + res.update({ 'cossim Histogram': { 'count': hist_info[0], 'edges': hist_info[1], @@ -74,19 +143,34 @@ def cossim( }, 'cossim_mean': cossim.mean().item(), 'cossim_std': cossim.std().item() - } + }) + return res def scale( - tensors: List[torch.Tensor], **_kwargs + tensors: List[torch.Tensor], return_heatmap=False, **_kwargs ) -> torch.Tensor: + """ + Scale difference: ratio of absolute difference to average scale. + Complementary to cosine similarity, which measures the angle between two vectors and is invariant to scale. 
+ """ norm_0 = torch.norm(tensors[0], dim=1) norm_1 = torch.norm(tensors[1], dim=1) + res = {} + + if return_heatmap: + num_heads = tensors[0].shape[0] + heatmap = np.zeros((num_heads, num_heads)) + for i in range(num_heads): + for j in range(num_heads): + heatmap[i, j] = torch.abs(norm_0[i] - norm_1[j]) / ((norm_0[i] + norm_1[j]) / 2) + res.update({'Scale Heatmap': heatmap}) + scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) hist_info = compute_histogram(scale_diff, 100) - return { + res.update({ 'scale Histogram': { 'count': hist_info[0], 'edges': hist_info[1], @@ -94,15 +178,16 @@ def scale( }, 'scale_mean': scale_diff.mean().item(), 'scale_std': scale_diff.std().item() - } + }) + return res def mse( - tensors: List[torch.Tensor], heatmap=False, **_kwargs + tensors: List[torch.Tensor], return_heatmap: bool =False, **_kwargs ) -> torch.Tensor: - + """Mean squared error (MSE).""" res = {} - if heatmap: + if return_heatmap: num_heads = tensors[0].shape[0] heatmap = np.zeros((num_heads, num_heads)) for i in range(num_heads): @@ -125,63 +210,9 @@ def mse( }) return res -def ungroup_tensor(input_tensor, GQA_groups): - """ - Ungroup a tensor by repeating its columns. - - Args: - input_tensor (torch.Tensor): The input tensor to ungroup. - GQA_groups (int): The number of GQA groups. - - Returns: - torch.Tensor: The ungrouped tensor. - """ - rows, cols = input_tensor.shape - new_rows = rows * GQA_groups - ungrouped_tensor = torch.zeros(new_rows, cols) - - for i in range(GQA_groups): - ungrouped_tensor[i*rows:(i+1)*rows] = input_tensor[i].expand(rows, -1) - - return ungrouped_tensor - -def restructure_tensor(input_tensor, num_rows): - """ - Restructure a tensor by splitting its rows. - - Args: - input_tensor (torch.Tensor): The input tensor to restructure. - num_rows (int): The number of rows for splitting. - - Returns: - torch.Tensor: The restructured tensor. 
- """ - rows, cols = input_tensor.shape - new_rows = rows // num_rows - reshaped_tensor = input_tensor.view(new_rows, num_rows, cols) - restructured_tensor = reshaped_tensor.permute(1, 0, 2) - - return restructured_tensor - -def group_attn_head_weights(k_proj, q_proj, v_proj, o_proj, weight_info): - - num_heads = weight_info.num_heads - GQA_groups = weight_info.GQA_groups - - k_proj = ungroup_tensor(k_proj, GQA_groups) - v_proj = ungroup_tensor(v_proj, GQA_groups) - - k_proj = restructure_tensor(k_proj, num_heads) - v_proj = restructure_tensor(v_proj, num_heads) - q_proj = restructure_tensor(q_proj, num_heads) - o_proj = restructure_tensor(o_proj.T, num_heads) - - return k_proj, v_proj, q_proj, o_proj - -class AllMetricTask(Task[torch.Tensor]): +class MLPTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo - angular_distance: bool def uses_accelerator(self) -> bool: return True @@ -197,7 +228,7 @@ def execute( res = {} if 'mlp' in self.weight_info.name: - res.update(cossim(tensors, angular_distance=self.angular_distance)) + res.update(cossim(tensors)) res.update(SMAPE(tensors)) res.update(scale(tensors)) res.update(mse(tensors)) @@ -237,11 +268,14 @@ def execute( # Metrics for heads res.update(mse([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)], - heatmap=True)) + return_heatmap=True)) res.update(cossim([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)], - angular_distance=True)) + return_heatmap=True)) res.update(scale([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True)) + res.update(SMAPE([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)])) return res @@ -285,12 +319,8 @@ class AllMetric(MetricMethod): attn_info_dict: Optional[Dict[str, WeightInfo]] = {} - attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] + attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] # hard-coded for now block_count: Optional[int] = 0 - def parameters(self) -> List[ConfigParameterDef]: - return [ - ConfigParameterDef(name="angular_distance", required=False, default_value=False), - ] def make_task( self, *, @@ -316,18 +346,18 @@ def make_task( force_dtype=None, optional=False, aliases=None, - GQA_groups=4, - num_heads=32 + GQA_groups=4, # hard-coded for now + num_heads=32 # hard-coded for now ) self.block_count += 1 return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) if 'mlp' in output_weight.name: - return AllMetricTask( + return MLPTask( gather_tensors=tensors, weight_info=output_weight, - angular_distance=parameters["angular_distance"] ) else: - return DummyTask(gather_tensors=tensors, weight_info=output_weight) + # Executor expects a task to be returned + return DummyTask(gather_tensors=tensors, weight_info=output_weight) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 9acb735d..60719a87 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import dash from dash import dcc, html -from dash.dependencies import Input, Output, State +from dash.dependencies import Input, Output class MetricsHandler: """ @@ -81,8 +81,6 @@ def plotly_line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): name='Line Plot' ) - - def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, 
Any]] = {}): """ Plot the stat values with optional error bars. @@ -137,7 +135,7 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): ax.set_ylabel(kwargs.get('ylabel', stat)) ax.set_xticks(np.arange(len(self.layer_names))) ax.set_xticklabels(self.layer_names, rotation=45) - ax.set_title(kwargs.get('title', f'{stat.replace("_", " ").capitalize()}')) + ax.set_title(kwargs.get('title', f'{stat.replace("_", " ").title()}')) # Set additional attributes for kwarg in ax_kwargs: @@ -209,7 +207,7 @@ def _add_nodes_and_edges(self, hierarchy): def construct_graph(self): self._add_nodes_and_edges(self.hierarchy) - def plot_graph(self, save_to: str = None): + def plot_graph(self, colour_by='cossim_mean', save_to: str = None): """ Plot the graph using Plotly for interactivity. """ @@ -248,14 +246,13 @@ def plot_graph(self, save_to: str = None): for metric in metrics_to_plot: if metric in metric_values: value = metric_values[metric] - hover += f"
{metric.replace('_', ' ').capitalize()}: {value:.4f}{'%' if 'SMAPE' in metric else ''}" + hover += f"
{metric.replace('_', ' ').title()}: {value:.4f}{'%' if 'SMAPE' in metric else ''}" node_values[metric].append(value) node_text.append(node) hover_text.append(hover) - first_metric = metrics_to_plot[0] - node_colors = [value for value in node_values[first_metric]] + node_colors = [value for value in node_values[colour_by]] node_trace = go.Scatter( x=node_x, y=node_y, @@ -268,12 +265,12 @@ def plot_graph(self, save_to: str = None): showscale=True, colorscale='Viridis', color=node_colors, - cmin=min(node_values[first_metric]), - cmax=max(node_values[first_metric]), + cmin=min(node_values[colour_by]), + cmax=max(node_values[colour_by]), size=10, colorbar=dict( thickness=15, - title=first_metric.replace('_', ' ').capitalize(), + title=colour_by.replace('_', ' ').title(), xanchor='left', titleside='right', ), @@ -294,11 +291,16 @@ def plot_graph(self, save_to: str = None): def create_app(nn_graph): app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css']) - # app = dash.Dash(__name__, external_stylesheets=['https://bootswatch.com/4/darkly/bootstrap.min.css']) app.layout = html.Div([ html.Div([ - html.H1('Neural Network Similarity Vislualisation', style={'textAlign': 'center'}), + html.H1('Network Weights Similarity Visualisation', style={'textAlign': 'center', 'padding': '20px'}), + dcc.Dropdown( + id='colour-by-dropdown', + options=[{'label': metric.replace('_', ' ').title(), 'value': metric} for metric in nn_graph.metric_handler.stat_names if 'mean' in metric], + value='cossim_mean', + style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} + ), dcc.Graph( id='graph', figure=nn_graph.plot_graph(), @@ -313,7 +315,7 @@ def create_app(nn_graph): options=[], style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} ), - dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '50vh'}), + dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}), ], className='container-fluid'), html.Div([ @@ -332,7 +334,7 @@ def create_app(nn_graph): Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), [Input('graph', 'clickData')] ) - def update_dropdown_options(clickData): + def update_metric_dropdown_options(clickData): if clickData is None: return [] @@ -341,7 +343,7 @@ def update_dropdown_options(clickData): options = list(nn_graph.metric_handler.stats_at_layer(node_name).keys()) options = [option for option in options if 'std' not in option] options = [ - {'label': option.replace('_', ' ').capitalize(), 'value': option} for option in options if 'mean' not in option + {'label': option.replace('_', ' ').title(), 'value': option} for option in options if 'mean' not in option ] return options, options[0]['value'] @@ -373,7 +375,7 @@ def display_node_data(clickData, selected_metric): yaxis_title="Value" ) elif 'heatmap' in selected_metric or 'Heatmap' in selected_metric: - trace = nn_graph.metric_handler.heatmap_plot(layer_name=node_name, stat='MSE Attn Heatmap') + trace = nn_graph.metric_handler.heatmap_plot(layer_name=node_name, stat=selected_metric) fig.add_trace(trace) fig.update_layout( title=f"{node_name} | {selected_metric}", @@ -392,10 +394,17 @@ def update_line_plot(selected_metric): stat_values = [nn_graph.metric_handler.stats_at_layer(layer)[selected_metric] for layer in nn_graph.metric_handler.layer_names] fig.add_trace(go.Scatter(x=nn_graph.metric_handler.layer_names, y=stat_values, mode='lines+markers')) 
fig.update_layout( - title=f"{selected_metric.replace('_', ' ').capitalize()} Across Layers", + title=f"{selected_metric.replace('_', ' ').title()} Across Layers", xaxis_title="Layer", - yaxis_title=selected_metric.replace('_', ' ').capitalize() + yaxis_title=selected_metric.replace('_', ' ').title() ) return fig + @app.callback( + Output('graph', 'figure'), + [Input('colour-by-dropdown', 'value')] + ) + def update_graph_colour(colour_by): + return nn_graph.plot_graph(colour_by=colour_by) + return app diff --git a/run_metrics.py b/run_metrics.py index fba0badd..cc5c3f07 100644 --- a/run_metrics.py +++ b/run_metrics.py @@ -15,10 +15,10 @@ #%% with open(CONFIG_YML, "r", encoding="utf-8") as fp: metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) - +#%% if __name__ == '__main__': - out = run_measure( + metrics = run_measure( metric_config, out_path=OUTPUT_PATH, options=MergeOptions( @@ -29,7 +29,7 @@ ), ) - nn_graph = ModelGraph([pair for pair in out if pair[1] is not None]) + nn_graph = ModelGraph(metrics) nn_graph.construct_graph() app = create_app(nn_graph=nn_graph) From f88f9042da84ead44d23755640a7f6a54765faad Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 12 Jun 2024 15:17:31 +0100 Subject: [PATCH 11/64] vectorise heatmap computation --- examples/linear_small.yml | 6 +--- mergekit/metric_methods/all_metrics.py | 42 ++++++++++++++++---------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/examples/linear_small.yml b/examples/linear_small.yml index add503f2..3f13c4d4 100644 --- a/examples/linear_small.yml +++ b/examples/linear_small.yml @@ -1,10 +1,6 @@ models: - model: BEE-spoke-data/smol_llama-220M-GQA - - model: BEE-spoke-data/smol_llama-220M-openhermes - # - model: psmathur/orca_mini_v3_13b - # - model: garage-bAInd/Platypus2-13B - -metric_method: scale +metric_method: all dtype: float32 diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index c0dacaf0..d0b6f2d5 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -159,15 +159,18 @@ def scale( res = {} + scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) + if return_heatmap: - num_heads = tensors[0].shape[0] - heatmap = np.zeros((num_heads, num_heads)) - for i in range(num_heads): - for j in range(num_heads): - heatmap[i, j] = torch.abs(norm_0[i] - norm_1[j]) / ((norm_0[i] + norm_1[j]) / 2) + norm_0 = norm_0.unsqueeze(1) # shape becomes [num_heads, 1] + norm_1 = norm_1.unsqueeze(0) # shape becomes [1, num_heads] + + # Compute the scale difference between each pair of heads by broadcasting + heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) res.update({'Scale Heatmap': heatmap}) - scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) + assert torch.isclose(scale_diff, heatmap.diagonal(), atol=1e-4).all(), "Diagonal elements of scale difference matrix do not match" + hist_info = compute_histogram(scale_diff, 100) res.update({ @@ -188,11 +191,16 @@ def mse( res = {} if return_heatmap: - num_heads = tensors[0].shape[0] - heatmap = np.zeros((num_heads, num_heads)) - for i in range(num_heads): - for j in range(num_heads): - heatmap[i, j] = ((tensors[0][i] - tensors[1][j]) ** 2).mean().item() + # Expand dimensions for broadcasting + tensors_0_exp = tensors[0].unsqueeze(1) # shape becomes [num_heads, 1, ...] + tensors_1_exp = tensors[1].unsqueeze(0) # shape becomes [1, num_heads, ...] 
+ + # Compute squared differences + diffs = (tensors_0_exp - tensors_1_exp) ** 2 + + # Compute mean over all dimensions except the first two + heatmap = diffs.mean(dim=tuple(range(2, diffs.dim()))).numpy() + res['MSE Attn Heatmap'] = heatmap squared_diff = (tensors[0] - tensors[1]) ** 2 @@ -210,6 +218,8 @@ def mse( }) return res +# Tasks + class MLPTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo @@ -228,10 +238,10 @@ def execute( res = {} if 'mlp' in self.weight_info.name: - res.update(cossim(tensors)) + res.update(cossim(tensors, return_heatmap=True)) res.update(SMAPE(tensors)) - res.update(scale(tensors)) - res.update(mse(tensors)) + res.update(scale(tensors, return_heatmap=True)) + res.update(mse(tensors, return_heatmap=False)) # Highly inefficient return res @@ -309,8 +319,9 @@ def execute( def group_label(self) -> Optional[str]: return self.gather_tensors.group_label() - +# Metric method + class AllMetric(MetricMethod): attn_weight_tensors: Optional[list] = [] attn_weight_infos: Optional[list] = [] @@ -318,7 +329,6 @@ class AllMetric(MetricMethod): attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {} attn_info_dict: Optional[Dict[str, WeightInfo]] = {} - attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] # hard-coded for now block_count: Optional[int] = 0 def make_task( From f89df57a50e6011e1b9880e5f53ba370d407f1a7 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Fri, 14 Jun 2024 16:48:49 +0100 Subject: [PATCH 12/64] Address issue with lexicographical sort by adding leading zeros to layer name --- mergekit/graph.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/mergekit/graph.py b/mergekit/graph.py index 54db5275..79924c4c 100644 --- a/mergekit/graph.py +++ b/mergekit/graph.py @@ -107,6 +107,7 @@ def uses_accelerator(self) -> bool: """ return False +import re class Executor: """ @@ -242,13 +243,20 @@ def _make_schedule(self, targets: List[Task]) -> List[Task]: # they will be included in the final schedule edge_tups.append((Executor.DUMMY_TASK_VALUE, task)) + def _pad_numbers(s): + parts = s.split('.') + for i, part in enumerate(parts): + if part.isdigit(): + parts[i] = part.zfill(3) + return '.'.join(parts) + def _compare_key(task: Union[Task, str]): if task == Executor.DUMMY_TASK_VALUE: return ("", 0) - return ( - task.group_label() or "", - -task.priority(), - ) + group_label = task.group_label() or "" + padded_label = _pad_numbers(group_label) + priority = -task.priority() + return (padded_label, priority) graph = networkx.DiGraph(edge_tups) res = [ From 74c5d33ea8ec2a3eee8a184bc61eca15d6e2142d Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Fri, 14 Jun 2024 16:51:08 +0100 Subject: [PATCH 13/64] remove unnecessary import from last commit --- mergekit/graph.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mergekit/graph.py b/mergekit/graph.py index 79924c4c..d7c11933 100644 --- a/mergekit/graph.py +++ b/mergekit/graph.py @@ -107,8 +107,6 @@ def uses_accelerator(self) -> bool: """ return False -import re - class Executor: """ Schedules and executes a set of tasks and their dependencies. From 7c209d259406f4628bf3025286001281ecb12f19 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 11:46:48 +0100 Subject: [PATCH 14/64] add validation check to ensure MergeConfiguration method is either Merge OR Metri, not both. 
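With this validator in place, a configuration that sets both merge_method and metric_method (or neither) is rejected when the YAML is loaded. A minimal sketch of the intended behaviour, assuming pydantic surfaces the validator's ValueError as a ValidationError:

    from pydantic import ValidationError

    from mergekit.config import MergeConfiguration

    cfg = {
        "models": [
            {"model": "BEE-spoke-data/smol_llama-220M-GQA"},
            {"model": "BEE-spoke-data/smol_llama-220M-openhermes"},
        ],
        "merge_method": "linear",  # both a merge method ...
        "metric_method": "all",    # ... and a metric method: validate_methods should reject this
        "dtype": "float32",
    }

    try:
        MergeConfiguration.model_validate(cfg)
    except ValidationError as err:  # assumes pydantic wraps the ValueError raised by the validator
        print(err)  # mentions: only one of 'merge_method' or 'metric_method' can be provided
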
--- mergekit/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mergekit/config.py b/mergekit/config.py index f4180e29..b5863798 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -109,6 +109,14 @@ def validate_inputs(self): if ((not self.slices) and (not self.models)) or (self.slices and self.models): raise RuntimeError("Must specify either output slices or models to merge") return self + + @model_validator(mode="after") + def validate_methods(self): + if not self.merge_method and not self.metric_method: + raise ValueError("Either 'merge_method' or 'metric_method' must be provided.") + if self.merge_method and self.metric_method: + raise ValueError("Only one of 'merge_method' or 'metric_method' can be provided, not both.") + return self def to_yaml(self) -> str: return yaml.dump( From db65f83a72af4786a5cda2b511c55a9d39975b1f Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 11:47:44 +0100 Subject: [PATCH 15/64] moved run_metrics and use click to enable commandline control over arguments --- mergekit/scripts/run_metrics.py | 38 +++++++++++++++++++++++++++++++++ run_metrics.py | 36 ------------------------------- 2 files changed, 38 insertions(+), 36 deletions(-) create mode 100644 mergekit/scripts/run_metrics.py delete mode 100644 run_metrics.py diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py new file mode 100644 index 00000000..cf6976ef --- /dev/null +++ b/mergekit/scripts/run_metrics.py @@ -0,0 +1,38 @@ +import click +import torch +import yaml + +from mergekit.config import MergeConfiguration +from mergekit.merge import MergeOptions +from mergekit.measure import run_measure +from mergekit.plot_tools.plot_tools import ModelGraph, create_app + +@click.command() +@click.option('--output_path', default="./merged", help='folder to store the result in.') +@click.option('--config_yml', default="./examples/metrics-small.yml", help='merge configuration file.') +@click.option('--copy_tokenizer', default=True, help='') +@click.option('--lazy_unpickle', default=False, help='experimental low-memory model loader.') +@click.option('--low_cpu_memory', default=False, help='enable if you somehow have more VRAM than RAM+swap') +def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory): + with open(config_yml, "r", encoding="utf-8") as fp: + metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) + + metrics = run_measure( + metric_config, + out_path=output_path, + options=MergeOptions( + cuda=torch.cuda.is_available(), + copy_tokenizer=copy_tokenizer, + lazy_unpickle=lazy_unpickle, + low_cpu_memory=low_cpu_memory, + ), + ) + + nn_graph = ModelGraph(metrics) + nn_graph.construct_graph() + + app = create_app(nn_graph=nn_graph) + app.run_server() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/run_metrics.py b/run_metrics.py deleted file mode 100644 index cc5c3f07..00000000 --- a/run_metrics.py +++ /dev/null @@ -1,36 +0,0 @@ -#%% -OUTPUT_PATH = "./merged" # folder to store the result in -CONFIG_YML = "./examples/linear_small.yml" # merge configuration file -COPY_TOKENIZER = True # you want a tokenizer? 
yeah, that's what i thought -LAZY_UNPICKLE = False # experimental low-memory model loader -LOW_CPU_MEMORY = False # enable if you somehow have more VRAM than RAM+swap - -import torch -import yaml - -from mergekit.config import MergeConfiguration -from mergekit.merge import MergeOptions -from mergekit.measure import run_measure -from mergekit.plot_tools.plot_tools import ModelGraph, create_app -#%% -with open(CONFIG_YML, "r", encoding="utf-8") as fp: - metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) -#%% -if __name__ == '__main__': - - metrics = run_measure( - metric_config, - out_path=OUTPUT_PATH, - options=MergeOptions( - cuda=torch.cuda.is_available(), - copy_tokenizer=COPY_TOKENIZER, - lazy_unpickle=LAZY_UNPICKLE, - low_cpu_memory=LOW_CPU_MEMORY, - ), - ) - - nn_graph = ModelGraph(metrics) - nn_graph.construct_graph() - - app = create_app(nn_graph=nn_graph) - app.run_server() \ No newline at end of file From 3fe66a8d51b5919210f60f2af79fed9b23a10b7f Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 11:48:55 +0100 Subject: [PATCH 16/64] rename example config --- examples/linear_small.yml | 6 ------ examples/metrics-1.yml | 6 ++++++ 2 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 examples/linear_small.yml create mode 100644 examples/metrics-1.yml diff --git a/examples/linear_small.yml b/examples/linear_small.yml deleted file mode 100644 index 3f13c4d4..00000000 --- a/examples/linear_small.yml +++ /dev/null @@ -1,6 +0,0 @@ -models: - - model: BEE-spoke-data/smol_llama-220M-GQA - - model: BEE-spoke-data/smol_llama-220M-openhermes - -metric_method: all -dtype: float32 diff --git a/examples/metrics-1.yml b/examples/metrics-1.yml new file mode 100644 index 00000000..58641b85 --- /dev/null +++ b/examples/metrics-1.yml @@ -0,0 +1,6 @@ +models: + - model: huggyllama/llama-7b + - model: TheBloke/Llama-2-7B-fp16 + +metric_method: all +dtype: float32 From 62692d1e81620a80e4bbbbb438025dae10695a64 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 11:50:06 +0100 Subject: [PATCH 17/64] correct case for gqa_group name --- mergekit/architecture.py | 4 ++-- mergekit/metric_methods/all_metrics.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 66270127..f0efd329 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -43,7 +43,7 @@ class WeightInfo(BaseModel, frozen=True): List of alternative names for the weight, if applicable. force_dtype (Optional[str]): Mandatory dtype for the weight, if applicable. - GQA_groups (Optional[int]): + gqa_groups (Optional[int]): Number of groups for GQA-style weight sharing, if applicable. num_heads (Optional[int]): Number of heads for multihead attention, if applicable. 
@@ -57,7 +57,7 @@ class WeightInfo(BaseModel, frozen=True): aliases: Optional[Tuple[str, ...]] = None force_dtype: Optional[str] = None - GQA_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA + gqa_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA num_heads: Optional[int] = None diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index d0b6f2d5..d086140d 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -35,15 +35,15 @@ def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expec if len(tensors) != expected_tensors: raise RuntimeError(f"Expected {expected_tensors} tensors, got {len(tensors)}") -def ungroup_tensor(input_tensor: torch.Tensor, GQA_groups: int) -> torch.Tensor: +def ungroup_tensor(input_tensor: torch.Tensor, gqa_groups: int) -> torch.Tensor: """ Ungroup a grouped tensor by repeating its rows. """ rows, cols = input_tensor.shape - new_rows = rows * GQA_groups + new_rows = rows * gqa_groups ungrouped_tensor = torch.zeros(new_rows, cols) - for i in range(GQA_groups): + for i in range(gqa_groups): ungrouped_tensor[i*rows:(i+1)*rows] = input_tensor[i].expand(rows, -1) return ungrouped_tensor @@ -71,10 +71,10 @@ def group_attn_head_weights(k_proj: torch.Tensor, torch.Tensor]: num_heads = weight_info.num_heads - GQA_groups = weight_info.GQA_groups + gqa_groups = weight_info.gqa_groups - k_proj = ungroup_tensor(k_proj, GQA_groups) - v_proj = ungroup_tensor(v_proj, GQA_groups) + k_proj = ungroup_tensor(k_proj, gqa_groups) + v_proj = ungroup_tensor(v_proj, gqa_groups) k_proj = restructure_tensor(k_proj, num_heads) v_proj = restructure_tensor(v_proj, num_heads) @@ -356,7 +356,7 @@ def make_task( force_dtype=None, optional=False, aliases=None, - GQA_groups=4, # hard-coded for now + gqa_groups=4, # hard-coded for now num_heads=32 # hard-coded for now ) self.block_count += 1 From 2a0c52093859bdb4c60352783c232fa5002a7034 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 12:11:35 +0100 Subject: [PATCH 18/64] replace measure with merge + early out --- mergekit/measure.py | 92 --------------------------------- mergekit/merge.py | 6 +++ mergekit/scripts/run_metrics.py | 4 +- 3 files changed, 8 insertions(+), 94 deletions(-) delete mode 100644 mergekit/measure.py diff --git a/mergekit/measure.py b/mergekit/measure.py deleted file mode 100644 index 0c18a371..00000000 --- a/mergekit/measure.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. 
- -import logging - -import tqdm -import transformers - -from mergekit.architecture import get_architecture_info -from mergekit.config import MergeConfiguration -from mergekit.graph import Executor -from mergekit.io.tasks import LoaderCache -from mergekit.options import MergeOptions -from mergekit.plan import MergePlanner -from mergekit.merge import _model_out_config - - -def run_measure( - merge_config: MergeConfiguration, - out_path: str, - options: MergeOptions, -): - if options.random_seed is not None: - transformers.trainer_utils.set_seed(options.random_seed) - - if not merge_config.models and not merge_config.slices: - raise RuntimeError("No output requested") - - model_arch_info = [ - get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) - for m in merge_config.referenced_models() - ] - if not options.allow_crimes: - if not all(a == model_arch_info[0] for a in model_arch_info[1:]): - raise RuntimeError( - "Must specify --allow-crimes to attempt to mix different architectures" - ) - arch_info = model_arch_info[0] - - # initialize loader cache and set options - loader_cache = LoaderCache() - loader_cache.setup(options=options) - - # create config for output model - cfg_out = _model_out_config( - merge_config, arch_info, trust_remote_code=options.trust_remote_code - ) - - # warm up loader cache - for model in ( - pbar := tqdm.tqdm( - merge_config.referenced_models(), - desc="Warmup loader cache", - disable=options.quiet, - ) - ): - loader_cache.get(model) - del pbar - - logging.info("Planning operations") - targets = MergePlanner( - merge_config, - arch_info, - options=options, - out_model_config=cfg_out, - ).plan_to_disk(out_path=out_path) - - exec = Executor( - tasks=targets, - math_device="cuda" if options.cuda else "cpu", - storage_device="cuda" if options.low_cpu_memory else "cpu", - ) - - res = [] - for _task, value in exec.run(quiet=options.quiet): - res.append((_task, value)) - - return res - -__all__ = ["MergeOptions", "run_merge"] diff --git a/mergekit/merge.py b/mergekit/merge.py index d045644c..008fc55a 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -88,11 +88,17 @@ def run_merge( storage_device="cuda" if options.low_cpu_memory else "cpu", ) + metrics_out = [] tokenizer = None for _task, value in exec.run(quiet=options.quiet): + if merge_config.metric_method is not None: + metrics_out.append((_task, value)) if isinstance(value, TokenizerInfo): tokenizer = value.tokenizer + if metrics_out: + return metrics_out + if tokenizer: _update_config_vocab(cfg_out, tokenizer) diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py index cf6976ef..1090e63e 100644 --- a/mergekit/scripts/run_metrics.py +++ b/mergekit/scripts/run_metrics.py @@ -4,7 +4,7 @@ from mergekit.config import MergeConfiguration from mergekit.merge import MergeOptions -from mergekit.measure import run_measure +from mergekit.merge import run_merge from mergekit.plot_tools.plot_tools import ModelGraph, create_app @click.command() @@ -17,7 +17,7 @@ def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory) with open(config_yml, "r", encoding="utf-8") as fp: metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) - metrics = run_measure( + metrics = run_merge( metric_config, out_path=output_path, options=MergeOptions( From f71e36057204ffd1b081a3a02c1b31fdc3f9ff14 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 14:10:05 +0100 Subject: [PATCH 19/64] guard against divide by zero --- 
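The pairwise scale heatmap divides by the mean of two row norms, which is zero whenever both rows are all-zero, turning the corresponding entries into NaN. A small illustration of the failure mode and the epsilon guard (toy norm values, not taken from a real checkpoint):

    import torch

    # toy norms for two rows of each model's weight matrix
    norm_0 = torch.tensor([0.0, 1.0]).unsqueeze(1)  # shape [2, 1]
    norm_1 = torch.tensor([0.0, 2.0]).unsqueeze(0)  # shape [1, 2]

    unguarded = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2)
    guarded = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1 + 1e-10) / 2)

    print(unguarded[0, 0])  # tensor(nan): 0/0 where both norms are zero
    print(guarded[0, 0])    # tensor(0.): the epsilon keeps the denominator finite
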
mergekit/metric_methods/all_metrics.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index d086140d..a1938270 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -100,10 +100,10 @@ def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # Metric functions -def SMAPE( +def smape( tensors: List[torch.Tensor], **_kwargs ) -> Dict[str, Any]: - """Symmetric Mean Absolute Percentage Error (SMAPE).""" + """Symmetric Mean Absolute Percentage Error (smape).""" numerator = torch.abs(tensors[0] - tensors[1]) denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) @@ -111,13 +111,13 @@ def SMAPE( hist_info = compute_histogram(smape, 100) return { - 'SMAPE Histogram': { + 'smape Histogram': { 'count': hist_info[0], 'edges': hist_info[1], 'widths': hist_info[2] }, - 'SMAPE_mean': smape.mean().item(), - 'SMAPE_std': smape.std().item() + 'smape_mean': smape.mean().item(), + 'smape_std': smape.std().item() } def cossim( @@ -166,7 +166,7 @@ def scale( norm_1 = norm_1.unsqueeze(0) # shape becomes [1, num_heads] # Compute the scale difference between each pair of heads by broadcasting - heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) + heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1 + 1e-10) / 2) res.update({'Scale Heatmap': heatmap}) assert torch.isclose(scale_diff, heatmap.diagonal(), atol=1e-4).all(), "Diagonal elements of scale difference matrix do not match" @@ -239,7 +239,7 @@ def execute( if 'mlp' in self.weight_info.name: res.update(cossim(tensors, return_heatmap=True)) - res.update(SMAPE(tensors)) + res.update(smape(tensors)) res.update(scale(tensors, return_heatmap=True)) res.update(mse(tensors, return_heatmap=False)) # Highly inefficient @@ -285,7 +285,7 @@ def execute( res.update(scale([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)], return_heatmap=True)) - res.update(SMAPE([model_0_heads.view(model_0_heads.shape[0], -1), + res.update(smape([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)])) return res From 7e142664644aaa0867690834d410ddf1e9b3c26b Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 14:10:48 +0100 Subject: [PATCH 20/64] restore plan_to_disk functionality for merging. 
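With this change the same planner serves both paths: when a metric_method is configured, plan_to_disk returns the bare metric tasks (nothing is written to disk) and run_merge's early-out collects their results; otherwise the usual SaveTensor/FinalizeModel pipeline is rebuilt as before. Roughly, from the caller's side (a sketch; assumes MergeOptions' remaining fields keep their defaults and that examples/metrics-1.yml from earlier in this series is present):

    import torch
    import yaml

    from mergekit.config import MergeConfiguration
    from mergekit.merge import MergeOptions, run_merge

    with open("examples/metrics-1.yml", "r", encoding="utf-8") as fp:
        config = MergeConfiguration.model_validate(yaml.safe_load(fp))

    # metric_method is set in this config, so run_merge returns the collected
    # (task, value) pairs instead of writing model shards to out_path
    metrics = run_merge(
        config,
        out_path="./merged",
        options=MergeOptions(cuda=torch.cuda.is_available()),
    )
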
Move metrics planning to separate case --- mergekit/plan.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/mergekit/plan.py b/mergekit/plan.py index e1d6734a..03c485a5 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -246,17 +246,49 @@ def plan_slice(self, definition: OutputSliceDefinition): cfg_reader=cfg_reader, ) + def metrics_plan_to_disk(self) -> List[Task]: + """Plan the metrics to be streamed to disk, returning a list of tasks.""" + save_tasks = [] + for weight, tensor_task in self._tensors: + save_tasks.append( + tensor_task + ) + + return save_tasks + def plan_to_disk(self, out_path: str) -> List[Task]: """Plan the merge to be streamed to disk, returning a list of tasks.""" self._plan() + if self.config.metric_method: + return self.metrics_plan_to_disk() + + + writer_task = TensorWriterTask( + out_path=out_path, + max_shard_size=self.options.out_shard_size, + safe_serialization=self.options.safe_serialization, + ) save_tasks = [] for weight, tensor_task in self._tensors: save_tasks.append( - tensor_task + SaveTensor( + tensor_name=weight.name, + tensor_task=tensor_task, + writer_task=writer_task, + clone=self.options.clone_tensors, + optional=weight.optional, + dtype=weight.force_dtype or self.config.out_dtype, + ) ) + finalize = FinalizeModel( + tensor_save_tasks=tuple(save_tasks), writer_task=writer_task + ) - return save_tasks + res = save_tasks + [finalize] + if self._tokenizer_task: + res.append(self._tokenizer_task) + return res def plan_in_memory(self) -> List[ReturnTensor]: """Plan the merge to be performed in memory.""" From 0f3430fe741a661cb6a87d900a62ed9eae6deabd Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 15:04:15 +0100 Subject: [PATCH 21/64] add optional interactive plot packages --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a0e9db1..5afa4bb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = ["black~=24.4.2", "isort~=5.13.2", "pre-commit~=3.7.1"] test = ["pytest~=8.2.1"] evolve = ["ray", "cma", "lm_eval", "wandb"] vllm = ["vllm==0.3.2", "lm_eval[vllm]"] +interactive_plot = ["networkx", "plotly"] [project.urls] repository = "https://github.com/cg123/mergekit" @@ -70,7 +71,4 @@ include = '\.pyi?$' [tool.pytest.ini_options] minversion = "6.0" -filterwarnings = [ - "ignore::pydantic.PydanticDeprecatedSince20:huggingface_hub.*:", -] testpaths = ["tests"] From 8d68e3965c381ba1845cd17e6add8d9880a978d1 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 17 Jun 2024 15:35:54 +0100 Subject: [PATCH 22/64] minor cleanup --- mergekit/graph.py | 2 +- mergekit/plot_tools/plot_tools.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mergekit/graph.py b/mergekit/graph.py index d7c11933..fea69e37 100644 --- a/mergekit/graph.py +++ b/mergekit/graph.py @@ -37,7 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True): Abstract base class representing a task in a computational graph. This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy. - Note that PyDantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after. + Pydantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after. 
Attributes: Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution. diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 60719a87..a8a9a1ef 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -67,7 +67,6 @@ def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): plt.close() def plotly_line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): - fig = go.Figure() y = [self.all_stats[layer]['metric'][stat] for layer in self.layer_names] if f'{stat}'.replace('mean', 'std') in self.stat_names: @@ -77,6 +76,11 @@ def plotly_line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): return go.Scatter( x=self.layer_names, y=y, + error_y=dict( + type='data', + array=std_values, + visible=True + ), mode='lines+markers', name='Line Plot' ) From 404e39532435a15fec60541fc2db9620e221d80e Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 19 Jun 2024 12:11:49 +0100 Subject: [PATCH 23/64] Add GQA info to (llama) architecture and refactor --- mergekit/_data/architectures/llama.json | 10 ++++++++-- mergekit/architecture.py | 12 ++++++------ mergekit/metric_methods/all_metrics.py | 7 ++----- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/mergekit/_data/architectures/llama.json b/mergekit/_data/architectures/llama.json index c418f055..afae52eb 100644 --- a/mergekit/_data/architectures/llama.json +++ b/mergekit/_data/architectures/llama.json @@ -21,21 +21,27 @@ { "name": "model.layers.${layer_index}.self_attn.q_proj.weight", "input_space": "h_${layer_index}", + "num_attention_heads": "${num_attention_heads}", "output_space": "attn_qk_${layer_index}" }, { "name": "model.layers.${layer_index}.self_attn.k_proj.weight", "input_space": "h_${layer_index}", - "output_space": "attn_qk_${layer_index}" + "output_space": "attn_qk_${layer_index}", + "num_attention_heads": "${num_attention_heads}", + "num_key_value_heads": "${num_key_value_heads}" }, { "name": "model.layers.${layer_index}.self_attn.v_proj.weight", "input_space": "h_${layer_index}", - "output_space": "attn_v_${layer_index}" + "output_space": "attn_v_${layer_index}", + "num_attention_heads": "${num_attention_heads}", + "num_key_value_heads": "${num_key_value_heads}" }, { "name": "model.layers.${layer_index}.self_attn.o_proj.weight", "input_space": "attn_v_${layer_index}", + "num_attention_heads": "${num_attention_heads}", "output_space": "post_attn_${layer_index}" }, { diff --git a/mergekit/architecture.py b/mergekit/architecture.py index f0efd329..0a38fb6d 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -43,10 +43,10 @@ class WeightInfo(BaseModel, frozen=True): List of alternative names for the weight, if applicable. force_dtype (Optional[str]): Mandatory dtype for the weight, if applicable. - gqa_groups (Optional[int]): - Number of groups for GQA-style weight sharing, if applicable. - num_heads (Optional[int]): - Number of heads for multihead attention, if applicable. + num_key_value_heads (Optional[int]): + Number of key-value heads in the weight, relevant for GQA, if applicable. + num_attention_heads (Optional[int]): + Number of attention heads in the weight, if applicable. 
""" name: str @@ -57,8 +57,8 @@ class WeightInfo(BaseModel, frozen=True): aliases: Optional[Tuple[str, ...]] = None force_dtype: Optional[str] = None - gqa_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA - num_heads: Optional[int] = None + num_key_value_heads: Union[int, str, None] = None + num_attention_heads: Union[int, str, None] = None class ProceduralSpaceInfo(BaseModel, frozen=True): diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index a1938270..99e3d928 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -70,8 +70,8 @@ def group_attn_head_weights(k_proj: torch.Tensor, torch.Tensor, torch.Tensor]: - num_heads = weight_info.num_heads - gqa_groups = weight_info.gqa_groups + num_heads = weight_info.num_attention_heads + gqa_groups = num_heads // weight_info.num_key_value_heads k_proj = ungroup_tensor(k_proj, gqa_groups) v_proj = ungroup_tensor(v_proj, gqa_groups) @@ -323,9 +323,6 @@ def group_label(self) -> Optional[str]: # Metric method class AllMetric(MetricMethod): - attn_weight_tensors: Optional[list] = [] - attn_weight_infos: Optional[list] = [] - attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {} attn_info_dict: Optional[Dict[str, WeightInfo]] = {} From 8e3c861b6cd91c8a67aa68fbe9513f431a74ddd0 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 19 Jun 2024 12:13:16 +0100 Subject: [PATCH 24/64] Pass GQA info from architecture json all the way to attn metrics. Generalised substitute function in architecture --- mergekit/architecture.py | 46 ++++++++++++++------------ mergekit/metric_methods/all_metrics.py | 13 +++++--- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 0a38fb6d..72100af3 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -179,30 +179,29 @@ class JSONArchitectureDefinition(BaseModel, frozen=True): class TemplateWithArithmetic(string.Template): idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)" - -def _template_substitution( - template: str, num_layers: int, layer_idx: Optional[int] = None +def _generic_template_substitution( + template: str, match_str: str = "num_layers", replace_with: int = 0 ) -> str: - if "{" not in template: + if f"{match_str}" not in template: return template substitutions = { - "num_layers": num_layers, - "num_layers+1": num_layers + 1, - "num_layers-1": num_layers - 1, + match_str: replace_with, + f"{match_str}+1": replace_with + 1, + f"{match_str}-1": replace_with - 1, } - if layer_idx is not None: - substitutions.update( - { - "layer_index": layer_idx, - "layer_index+1": layer_idx + 1, - "layer_index-1": layer_idx - 1, - } - ) - return TemplateWithArithmetic(template).substitute(substitutions) +def _template_substitution( + template: str, substitute: Dict[str, Optional[int]] +) -> str: + for key, value in substitute.items(): + if value is not None: + template = _generic_template_substitution(template, key, value) + + return template + class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True): definition: JSONArchitectureDefinition @@ -213,18 +212,23 @@ def _substitute( config: PretrainedConfig, layer_idx: Optional[int] = None, ) -> Union[WeightInfo, ProceduralSpaceInfo]: - num_layers = self.num_layers(config) - - obj_dict = item.model_dump(mode="json", exclude_unset=True) + substitute = { + "num_layers": self.num_layers(config), + "layer_index": layer_idx, + "num_attention_heads": getattr(config, 
"num_attention_heads") if getattr(config, "num_attention_heads") is not None else None, + "num_key_value_heads": getattr(config, "num_key_value_heads") if getattr(config, "num_key_value_heads") is not None else None, + } + + obj_dict = item.model_dump(mode="json", exclude_unset=False) for key in obj_dict: if isinstance(obj_dict[key], str): obj_dict[key] = _template_substitution( - obj_dict[key], num_layers, layer_idx + obj_dict[key], substitute ) elif isinstance(obj_dict[key], list): obj_dict[key] = [ ( - _template_substitution(s, num_layers, layer_idx) + _template_substitution(s, substitute) if isinstance(s, str) else s ) diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index 99e3d928..3b1a9770 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -71,10 +71,13 @@ def group_attn_head_weights(k_proj: torch.Tensor, torch.Tensor]: num_heads = weight_info.num_attention_heads - gqa_groups = num_heads // weight_info.num_key_value_heads + assert num_heads is not None, "Number of attention heads is not defined" + + if getattr(weight_info, 'num_key_value_heads', None) and getattr(weight_info, 'num_key_value_heads', None) != 0: + gqa_groups = num_heads // weight_info.num_key_value_heads - k_proj = ungroup_tensor(k_proj, gqa_groups) - v_proj = ungroup_tensor(v_proj, gqa_groups) + k_proj = ungroup_tensor(k_proj, gqa_groups) + v_proj = ungroup_tensor(v_proj, gqa_groups) k_proj = restructure_tensor(k_proj, num_heads) v_proj = restructure_tensor(v_proj, num_heads) @@ -353,8 +356,8 @@ def make_task( force_dtype=None, optional=False, aliases=None, - gqa_groups=4, # hard-coded for now - num_heads=32 # hard-coded for now + num_key_value_heads=int(infos['k_proj'].num_key_value_heads), + num_attention_heads=int(infos['k_proj'].num_attention_heads) ) self.block_count += 1 return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) From a0e8c271e4b68475d1739ac64c60ff9dd04aa0e3 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 19 Jun 2024 16:20:39 +0100 Subject: [PATCH 25/64] re-organised and simplified dashboard view --- mergekit/metric_methods/all_metrics.py | 6 ++- mergekit/plot_tools/plot_tools.py | 51 ++++++-------------------- 2 files changed, 15 insertions(+), 42 deletions(-) diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index 3b1a9770..fbb935d6 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -155,6 +155,8 @@ def scale( """ Scale difference: ratio of absolute difference to average scale. Complementary to cosine similarity, which measures the angle between two vectors and is invariant to scale. 
+ + values close to 0 indicate that the scales of the two vectors are similar """ norm_0 = torch.norm(tensors[0], dim=1) @@ -241,9 +243,9 @@ def execute( res = {} if 'mlp' in self.weight_info.name: - res.update(cossim(tensors, return_heatmap=True)) + res.update(cossim(tensors, return_heatmap=False)) res.update(smape(tensors)) - res.update(scale(tensors, return_heatmap=True)) + res.update(scale(tensors, return_heatmap=False)) res.update(mse(tensors, return_heatmap=False)) # Highly inefficient return res diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index a8a9a1ef..eba812a6 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -300,16 +300,12 @@ def create_app(nn_graph): html.Div([ html.H1('Network Weights Similarity Visualisation', style={'textAlign': 'center', 'padding': '20px'}), dcc.Dropdown( - id='colour-by-dropdown', + id='line-plot-dropdown', options=[{'label': metric.replace('_', ' ').title(), 'value': metric} for metric in nn_graph.metric_handler.stat_names if 'mean' in metric], value='cossim_mean', style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} ), - dcc.Graph( - id='graph', - figure=nn_graph.plot_graph(), - style={'width': '100%', 'height': '50vh'} - ), + dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}), ], className='container-fluid'), html.Div([ @@ -320,51 +316,40 @@ def create_app(nn_graph): style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} ), dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}), - ], className='container-fluid'), - - html.Div([ - html.H3('Metrics Across Layers', style={'textAlign': 'center'}), - dcc.Dropdown( - id='line-plot-dropdown', - options=[{'label': metric.replace('_', ' ').title(), 'value': metric} for metric in nn_graph.metric_handler.stat_names if 'mean' in metric], - value='cossim_mean', - style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} - ), - dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}), - ], className='container-fluid'), + ], className='container-fluid') ]) @app.callback( Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), - [Input('graph', 'clickData')] + [Input('line-plot', 'clickData')] ) def update_metric_dropdown_options(clickData): if clickData is None: - return [] + return [], None try: - node_name = clickData['points'][0]['text'] + node_name = clickData['points'][0]['x'] options = list(nn_graph.metric_handler.stats_at_layer(node_name).keys()) options = [option for option in options if 'std' not in option] options = [ {'label': option.replace('_', ' ').title(), 'value': option} for option in options if 'mean' not in option ] - return options, options[0]['value'] + return options, options[0]['value'] if options else None except (KeyError, IndexError, TypeError) as e: print(f"Error processing clickData: {e}") - return [] + return [], None @app.callback( Output('node-details-plot', 'figure'), - [Input('graph', 'clickData'), Input('metric-dropdown', 'value')], + [Input('line-plot', 'clickData'), Input('metric-dropdown', 'value')], ) def display_node_data(clickData, selected_metric): if clickData is None: return go.Figure() try: - node_name = clickData['points'][0]['text'] + node_name = clickData['points'][0]['x'] except (KeyError, IndexError, TypeError) as e: print(f"Error processing clickData: {e}") return go.Figure() @@ -395,20 +380,6 @@ def 
display_node_data(clickData, selected_metric): ) def update_line_plot(selected_metric): fig = go.Figure() - stat_values = [nn_graph.metric_handler.stats_at_layer(layer)[selected_metric] for layer in nn_graph.metric_handler.layer_names] - fig.add_trace(go.Scatter(x=nn_graph.metric_handler.layer_names, y=stat_values, mode='lines+markers')) - fig.update_layout( - title=f"{selected_metric.replace('_', ' ').title()} Across Layers", - xaxis_title="Layer", - yaxis_title=selected_metric.replace('_', ' ').title() - ) + fig.add_trace(nn_graph.metric_handler.plotly_line_plot(selected_metric)) return fig - - @app.callback( - Output('graph', 'figure'), - [Input('colour-by-dropdown', 'value')] - ) - def update_graph_colour(colour_by): - return nn_graph.plot_graph(colour_by=colour_by) - return app From 7e2b552a34eaa6d0732b3193cdc9c139bfbf1471 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 19 Jun 2024 17:52:14 +0100 Subject: [PATCH 26/64] colour-categorise lineplot points by layertime --- mergekit/plot_tools/plot_tools.py | 74 ++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index eba812a6..43cac691 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -4,6 +4,7 @@ import networkx as nx import plotly.graph_objects as go import matplotlib.pyplot as plt +import matplotlib.colors as mcolors import dash from dash import dcc, html from dash.dependencies import Input, Output @@ -66,24 +67,64 @@ def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): plt.show() plt.close() - def plotly_line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): + def categorise_layers(self, layer_names): + # Hardcoded for now - can be extended to include more categories or further generalised + categories = [] + for name in layer_names: + if 'Attention Block' in name: + categories.append('Attention Block') + elif 'model.layer' in name: + categories.append('Model Layer') + else: + categories.append('Other') + return categories + def plotly_line_plot(self, stat: str, save_to: Optional[str] = None, **kwargs): y = [self.all_stats[layer]['metric'][stat] for layer in self.layer_names] - if f'{stat}'.replace('mean', 'std') in self.stat_names: - std_stat = f'{stat}'.replace('mean', 'std') + std_stat = f'{stat}'.replace('mean', 'std') + if std_stat in self.stat_names: std_values = [self.all_stats[layer]['metric'].get(std_stat) for layer in self.layer_names] + else: + std_values = [0] * len(self.layer_names) + + layer_categories = self.categorise_layers(self.layer_names) + unique_categories = list(set(layer_categories)) + + # Assign a unique color to each category + cmap = plt.get_cmap('jet', len(unique_categories)) + colors = [mcolors.to_hex(cmap(i)) for i in range(len(unique_categories))] - return go.Scatter( - x=self.layer_names, - y=y, - error_y=dict( - type='data', - array=std_values, - visible=True - ), - mode='lines+markers', - name='Line Plot' + category_styles = {cat: colors[i % len(colors)] for i, cat in enumerate(unique_categories)} + + fig = go.Figure() + + for category in unique_categories: + y_category = [y[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + std_category = [std_values[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + + fig.add_trace(go.Scatter( + x=self.layer_names, + y=y_category, + error_y=dict( + type='data', + array=std_category, + visible=True 
+ ), + mode='markers', + name=category, + marker=dict(color=category_styles[category]) + )) + + fig.update_layout( + title=f"{stat.replace('_', ' ').title()} Across Layers", + xaxis_title="Layer", + yaxis_title=stat.replace('_', ' ').title() ) + + if save_to: + fig.write_image(save_to) + + return fig def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): """ @@ -379,7 +420,8 @@ def display_node_data(clickData, selected_metric): [Input('line-plot-dropdown', 'value')] ) def update_line_plot(selected_metric): - fig = go.Figure() - fig.add_trace(nn_graph.metric_handler.plotly_line_plot(selected_metric)) - return fig + # fig = go.Figure() + # fig.add_trace(nn_graph.metric_handler.plotly_line_plot(selected_metric)) + # return fig + return nn_graph.metric_handler.plotly_line_plot(selected_metric) return app From 69c3b15d57926386a045175230761580b75ca225 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Mon, 24 Jun 2024 14:51:12 +0100 Subject: [PATCH 27/64] restructure and refactor results and plotting --- .../{metrics-1.yml => metrics-llama-1v2.yml} | 0 examples/metrics-llama-codev2.yml | 6 + mergekit/metric_methods/all_metrics.py | 263 ++++++++----- mergekit/plot_tools/plot_tools.py | 356 +++++++----------- 4 files changed, 305 insertions(+), 320 deletions(-) rename examples/{metrics-1.yml => metrics-llama-1v2.yml} (100%) create mode 100644 examples/metrics-llama-codev2.yml diff --git a/examples/metrics-1.yml b/examples/metrics-llama-1v2.yml similarity index 100% rename from examples/metrics-1.yml rename to examples/metrics-llama-1v2.yml diff --git a/examples/metrics-llama-codev2.yml b/examples/metrics-llama-codev2.yml new file mode 100644 index 00000000..13413ce3 --- /dev/null +++ b/examples/metrics-llama-codev2.yml @@ -0,0 +1,6 @@ +models: + - model: meta-llama/CodeLlama-7b-Python-hf + - model: meta-llama/Llama-2-7b-hf + +metric_method: all +dtype: float32 diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index fbb935d6..f996b46e 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -26,6 +26,81 @@ import numpy as np + +from dataclasses import dataclass, field +from typing import Dict, List, Any + + +# Results +# └── layers: Dict[str, Layer] +# └── Layer +# ├── name: str +# ├── metrics: Dict[str, Metric] +# │ └── Metric +# │ ├── name: str +# │ ├── histogram: Histogram (optional) +# │ │ ├── count: List[float] +# │ │ ├── edges: List[float] +# │ │ └── widths: List[float] +# │ ├── mean_std: MeanStd (optional) +# │ │ ├── mean: float +# │ │ └── std: float (optional) +# │ ├── heatmap: Heatmap (optional) +# │ │ └── data: torch.Tensor +# │ ├── value: float (optional) +# │ └── additional_data: Dict[str, Any] +# └── weight_info: WeightInfo + +@dataclass +class MeanStd: + mean: float + std: Optional[float] = None + +@dataclass +class Heatmap: + data: torch.Tensor + +@dataclass +class Histogram: + count: List[float] + edges: List[float] + widths: List[float] + +@dataclass +class Metric: + histogram: Histogram = None + mean_std: MeanStd = None + heatmap: Heatmap = None + value: float = None + additional_data: Dict[str, Any] = field(default_factory=dict) + + def filled_attributes(self) -> List[str]: + filled_attrs = [] + for attr, value in self.__dict__.items(): + if value is not None: + filled_attrs.append(attr) + return filled_attrs + +@dataclass +class Layer: + metrics: Dict[str, Metric] + weight_info: WeightInfo + + def metrics_with_property(self, prop: str) -> List[str]: + return 
[name for name, metric in self.metrics.items() if getattr(metric, prop) is not None] + +class Results: + # Class to store the statistics for each layer, redundant - remove or add more functionality + def __init__(self): + self.layers: Dict[str, Layer] = {} + + def add_layer(self, layer: Layer, name: str): + if name not in self.layers.keys(): + self.layers[name] = layer + + def get_metric(self, layer_name: str, metric_name: str) -> Metric: + return self.get_layer(layer_name, metric_name) + def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): """Validate tensor shapes and count.""" unique_shapes = set(t.shape for t in tensors) @@ -105,7 +180,7 @@ def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: def smape( tensors: List[torch.Tensor], **_kwargs -) -> Dict[str, Any]: +) -> Metric: """Symmetric Mean Absolute Percentage Error (smape).""" numerator = torch.abs(tensors[0] - tensors[1]) @@ -113,45 +188,34 @@ def smape( smape = torch.mean(torch.div(numerator, denominator), dim=1) hist_info = compute_histogram(smape, 100) - return { - 'smape Histogram': { - 'count': hist_info[0], - 'edges': hist_info[1], - 'widths': hist_info[2] - }, - 'smape_mean': smape.mean().item(), - 'smape_std': smape.std().item() - } + + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=smape.mean().item(), std=smape.std().item()) + ) def cossim( tensors: List[torch.Tensor], return_heatmap=False, **_kwargs -) -> torch.Tensor: +) -> Metric: """Cosine similarity""" cossim = F.cosine_similarity(tensors[0], tensors[1], dim=1) - res = {} - if return_heatmap: - res.update({'Cossim Heatmap': cossim_heatmap(tensors[0], tensors[1])}) + heatmap = cossim_heatmap(tensors[0], tensors[1]) assert torch.isclose(cossim, cossim, atol=1e-6).all(), "NaNs in cosine similarity" - assert torch.isclose(cossim, cossim_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-4).all(), "Diagonal elements of cosine similarity matrix do not match" + assert torch.isclose(cossim, cossim_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-2).all(), "Diagonal elements of cosine similarity matrix do not match" hist_info = compute_histogram(cossim, 100) - res.update({ - 'cossim Histogram': { - 'count': hist_info[0], - 'edges': hist_info[1], - 'widths': hist_info[2] - }, - 'cossim_mean': cossim.mean().item(), - 'cossim_std': cossim.std().item() - }) - return res + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=cossim.mean().item(), std=cossim.std().item()), + heatmap=Heatmap(data=heatmap) if return_heatmap else None + ) def scale( tensors: List[torch.Tensor], return_heatmap=False, **_kwargs -) -> torch.Tensor: +) -> Metric: """ Scale difference: ratio of absolute difference to average scale. Complementary to cosine similarity, which measures the angle between two vectors and is invariant to scale. 
@@ -172,29 +236,21 @@ def scale( # Compute the scale difference between each pair of heads by broadcasting heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1 + 1e-10) / 2) - res.update({'Scale Heatmap': heatmap}) assert torch.isclose(scale_diff, heatmap.diagonal(), atol=1e-4).all(), "Diagonal elements of scale difference matrix do not match" - hist_info = compute_histogram(scale_diff, 100) - res.update({ - 'scale Histogram': { - 'count': hist_info[0], - 'edges': hist_info[1], - 'widths': hist_info[2] - }, - 'scale_mean': scale_diff.mean().item(), - 'scale_std': scale_diff.std().item() - }) - return res + + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=scale_diff.mean().item(), std=scale_diff.std().item()), + heatmap=Heatmap(data=heatmap) if return_heatmap else None + ) def mse( tensors: List[torch.Tensor], return_heatmap: bool =False, **_kwargs -) -> torch.Tensor: +) -> Metric: """Mean squared error (MSE).""" - res = {} - if return_heatmap: # Expand dimensions for broadcasting tensors_0_exp = tensors[0].unsqueeze(1) # shape becomes [num_heads, 1, ...] @@ -206,22 +262,19 @@ def mse( # Compute mean over all dimensions except the first two heatmap = diffs.mean(dim=tuple(range(2, diffs.dim()))).numpy() - res['MSE Attn Heatmap'] = heatmap - squared_diff = (tensors[0] - tensors[1]) ** 2 mse_per_neuron = torch.mean(squared_diff, dim=1) hist_info = compute_histogram(mse_per_neuron, 100) - res.update({ - 'mse Histogram': { - 'count': hist_info[0], - 'edges': hist_info[1], - 'widths': hist_info[2] - }, - 'mse_mean': mse_per_neuron.mean().item(), - 'mse_std': mse_per_neuron.std().item() - }) - return res + + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=mse_per_neuron.mean().item(), std=mse_per_neuron.std().item()), + heatmap=Heatmap(data=heatmap) if return_heatmap else None + ) + +# Tensor Analysis (number of tensors can vary) + # Tasks @@ -238,17 +291,17 @@ def arguments(self) -> Dict[str, Task]: def execute( self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs ) -> torch.Tensor: - tensors = list(tensors.values()) - validate_tensors(tensors, self.weight_info, expected_tensors=2) - res = {} - if 'mlp' in self.weight_info.name: - - res.update(cossim(tensors, return_heatmap=False)) - res.update(smape(tensors)) - res.update(scale(tensors, return_heatmap=False)) - res.update(mse(tensors, return_heatmap=False)) # Highly inefficient - - return res + weights = list(tensors.values()) + validate_tensors(weights, self.weight_info, expected_tensors=2) + out = Layer(metrics={}, + weight_info=self.weight_info) + + out.metrics['cossim'] = cossim(weights, return_heatmap=False) + out.metrics['smape'] = smape(weights) + out.metrics['scale'] = scale(weights, return_heatmap=False) + out.metrics['mse'] = mse(weights, return_heatmap=False) # Highly inefficient + + return out def group_label(self) -> Optional[str]: return self.gather_tensors.group_label() @@ -266,10 +319,10 @@ def arguments(self) -> Dict[str, Task]: return self.weights def execute( - self, k_proj, v_proj, q_proj, o_proj, **_kwargs + self, k_proj: torch.Tensor, v_proj: torch.Tensor, q_proj: torch.Tensor, o_proj: torch.Tensor, **_kwargs ) -> torch.Tensor: # Add metrics for attention weights - res = {} + models = list(q_proj.keys()) k_proj_0, v_proj_0, q_proj_0, o_proj_0 = group_attn_head_weights(k_proj[models[0]], q_proj[models[0]], v_proj[models[0]], o_proj[models[0]], self.weight_info) @@ 
-277,27 +330,31 @@ def execute( # Metrics for K, V, Q, O projections + + # Metrics for heads + model_0_heads = torch.cat([k_proj_0, v_proj_0, q_proj_0, o_proj_0], dim=1) model_1_heads = torch.cat([k_proj_1, v_proj_1, q_proj_1, o_proj_1], dim=1) - # Metrics for heads - res.update(mse([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True)) - res.update(cossim([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True)) - res.update(scale([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True)) - res.update(smape([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)])) - - return res + out = Layer(metrics={}, + weight_info=self.weight_info) + + out.metrics['cossim'] = cossim([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True) + out.metrics['smape'] = smape([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)]) + out.metrics['scale'] = scale([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True) + out.metrics['mse'] = mse([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=False) + + return out def group_label(self) -> Optional[str]: - # Use max of the group labels - return max([gather_tensor.group_label() for gather_tensor in list(self.weights.values())]) # Check this (X) + return max([gather_tensor.group_label() for gather_tensor in list(self.weights.values())]) def __hash__(self): return hash(self.weight_info) @@ -343,31 +400,31 @@ def make_task( ) -> Task: if 'self_attn' in output_weight.name: - # collect all attention weights - for part in self.attn_parts: # also check only one key - if part in output_weight.name: - self.attn_weight_dict[part] = tensors - self.attn_info_dict[part] = output_weight - - # if all attention weights are collected, create attention task - if set(list(self.attn_weight_dict.keys())) == set(self.attn_parts): - weights, infos = self.attn_weight_dict, self.attn_info_dict - self.attn_weight_dict, self.attn_info_dict = {}, {} - weight_info = WeightInfo( - name=f"Attention Block {self.block_count}", - force_dtype=None, - optional=False, - aliases=None, - num_key_value_heads=int(infos['k_proj'].num_key_value_heads), - num_attention_heads=int(infos['k_proj'].num_attention_heads) - ) - self.block_count += 1 - return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) - if 'mlp' in output_weight.name: + return self.group_attn_heads(tensors, output_weight) + elif 'mlp' in output_weight.name: return MLPTask( gather_tensors=tensors, weight_info=output_weight, + intra_model_metrics=parameters['intra_model_metrics'] + ) + else: + # Executor expects a task to be returned + return DummyTask(gather_tensors=tensors, weight_info=output_weight) + + # if all attention weights are collected, create attention task + if set(list(self.attn_weight_dict.keys())) == set(self.attn_parts): + weights, infos = self.attn_weight_dict, self.attn_info_dict + self.attn_weight_dict, self.attn_info_dict = {}, {} + weight_info = WeightInfo( + name=f"Attention Block {self.block_count}", + force_dtype=None, + optional=False, + aliases=None, + num_key_value_heads=int(infos['k_proj'].num_key_value_heads), + 
num_attention_heads=int(infos['k_proj'].num_attention_heads) ) + self.block_count += 1 + return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) else: # Executor expects a task to be returned return DummyTask(gather_tensors=tensors, weight_info=output_weight) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 43cac691..9304dd1a 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -7,7 +7,8 @@ import matplotlib.colors as mcolors import dash from dash import dcc, html -from dash.dependencies import Input, Output +from dash.dependencies import Input, Output, State +from mergekit.metric_methods.all_metrics import Layer, Results class MetricsHandler: """ @@ -30,29 +31,27 @@ class MetricsHandler: plot_node_hist: Plot a histogram of the stat for a specific layer. """ def __init__(self): - self.all_stats: Dict[str, Dict[str, Any]] = {} - self.stat_names: List = [] self.layer_names: List[str] = [] + self.results = Results() + self.stat_names: List[str] = [] - def load_metrics(self, metrics: List[Tuple[Task, Dict[str, Any]]]): + def load_metrics(self, metrics: List[Tuple[Task, Layer]]): for task, metric in metrics: if metric is not None: - self.all_stats[task.weight_info.name] = {'metric':metric, - 'weight_info':task.weight_info} - self.layer_names.append(task.weight_info.name) - self.stat_names.extend(metric.keys()) - + self.results.add_layer(metric, name=task.weight_info.name) + self.stat_names.extend(list(metric.metrics.keys())) + self.layer_names = list(self.results.layers.keys()) self.stat_names = list(set(self.stat_names)) def stats_at_layer(self, layer_name: str) -> Dict[str, Any]: - if layer_name not in self.all_stats: + if layer_name not in self.results.layers: raise ValueError(f"Layer {layer_name} not found") - return self.all_stats[layer_name]['metric'] + return self.results.layers[layer_name] def info_at_layer(self, layer_name: str): - if layer_name not in self.all_stats: + if layer_name not in self.results.layers: raise ValueError(f"Layer {layer_name} not found") - return self.all_stats[layer_name]['weight_info'] + return self.results.layers[layer_name].weight_info def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): fig, ax = plt.subplots() @@ -68,41 +67,51 @@ def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): plt.close() def categorise_layers(self, layer_names): - # Hardcoded for now - can be extended to include more categories or further generalised + # Hardcoded layernames for now - can be extended to include more categories or further generalised based on config categories = [] for name in layer_names: if 'Attention Block' in name: categories.append('Attention Block') - elif 'model.layer' in name: - categories.append('Model Layer') + elif 'mlp' in name: + categories.append('MLP') else: categories.append('Other') return categories - def plotly_line_plot(self, stat: str, save_to: Optional[str] = None, **kwargs): - y = [self.all_stats[layer]['metric'][stat] for layer in self.layer_names] - std_stat = f'{stat}'.replace('mean', 'std') - if std_stat in self.stat_names: - std_values = [self.all_stats[layer]['metric'].get(std_stat) for layer in self.layer_names] - else: - std_values = [0] * len(self.layer_names) + def plotly_line_plot(self, stat: str, **kwargs): + """ + Plot the stat values across layers using Plotly. + + Args: + stat (str): The name of the stat to plot. + Returns: + List[go.Scatter]: List of Plotly Scatter objects. 
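        Example (illustrative sketch, not from the patch; assumes a populated MetricsHandler named handler):
            fig = go.Figure(data=handler.plotly_line_plot('cossim'))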
+ """ + + if stat not in self.stat_names: + print(f"Stat {stat} not found") + return + + means = [self.results.layers[layer].metrics[stat].mean_std.mean for layer in self.layer_names] + stds = [self.results.layers[layer].metrics[stat].mean_std.std for layer in self.layer_names] + layer_categories = self.categorise_layers(self.layer_names) unique_categories = list(set(layer_categories)) # Assign a unique color to each category - cmap = plt.get_cmap('jet', len(unique_categories)) + cmap = plt.get_cmap('Set1', len(unique_categories)) colors = [mcolors.to_hex(cmap(i)) for i in range(len(unique_categories))] category_styles = {cat: colors[i % len(colors)] for i, cat in enumerate(unique_categories)} - fig = go.Figure() + traces = [] for category in unique_categories: - y_category = [y[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] - std_category = [std_values[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + y_category = [means[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + std_category = [stds[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] - fig.add_trace(go.Scatter( + traces.append(go.Scatter( x=self.layer_names, y=y_category, error_y=dict( @@ -114,17 +123,7 @@ def plotly_line_plot(self, stat: str, save_to: Optional[str] = None, **kwargs): name=category, marker=dict(color=category_styles[category]) )) - - fig.update_layout( - title=f"{stat.replace('_', ' ').title()} Across Layers", - xaxis_title="Layer", - yaxis_title=stat.replace('_', ' ').title() - ) - - if save_to: - fig.write_image(save_to) - - return fig + return traces def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): """ @@ -132,25 +131,14 @@ def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): Args: ax: The matplotlib Axes object. - stat_values (List[float]): The values of the stat to plot. - std_values (Optional[List[float]]): The standard deviation values for error bars. - **kwargs: Additional keyword arguments for plotting. + stat (str): The name of the stat to plot. + plot_kwargs: Additional keyword arguments for plotting. """ - std_values = None - if f'{stat}_mean' in self.stat_names: - std_stat = f"{stat}_std" - stat = f'{stat}_mean' - if std_stat in self.stat_names: - std_values = [self.all_stats[layer]['metric'].get(std_stat) for layer in self.layer_names] - - assert (stat in self.stat_names), f"Stat {stat} not found" - stat_values = [self.all_stats[layer]['metric'][stat] for layer in self.layer_names] - if std_values: - ax.errorbar(self.layer_names, stat_values, yerr=std_values, fmt='-o', **plot_kwargs) - else: - ax.plot(stat_values, **plot_kwargs) - - def heatmap_plot(self, layer_name:str, stat:str): + means = [self.results.layers[layer].metrics[stat].mean for layer in self.layer_names] + stds = [self.results.layers[layer].metrics[stat].std for layer in self.layer_names] + ax.errorbar(self.layer_names, means, yerr=stds, fmt='-o', **plot_kwargs) + + def plot_node_heatmap(self, layer_name:str, stat:str): """ Plot the stat values as a heatmap. @@ -160,7 +148,7 @@ def heatmap_plot(self, layer_name:str, stat:str): Returns: go.Heatmap: Plotly Heatmap object. 
""" - heatmap = self.all_stats[layer_name]['metric'][stat] + heatmap = self.results.layers[layer_name].metrics[stat].heatmap.data return go.Heatmap( z=heatmap, @@ -189,11 +177,12 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): def plot_node_hist(self, layer_name: str, stat: str): - bin_counts, bin_edges, bin_widths = self.stats_at_layer(layer_name)[stat].values() + hist = self.stats_at_layer(layer_name).metrics[stat].histogram + count, edges, widths = hist.count, hist.edges, hist.widths return go.Bar( - x=bin_edges[:-1], - y=bin_counts, - width=bin_widths, + x=edges[:-1], + y=count, + width=widths, marker=dict( color='blue', line=dict( @@ -202,6 +191,17 @@ def plot_node_hist(self, layer_name: str, stat: str): ) ) ) + + def node_plot_options(self, node): + layer = self.results.layers[node] + + return [ + {"label": f"{metric.title()} Histogram", "value": [metric, 'histogram']} + for metric in layer.metrics_with_property('histogram') + ] + [ + {"label": f"{metric.title()} Heatmap", "value": [metric, 'heatmap']} + for metric in layer.metrics_with_property('heatmap') + ] class ModelGraph: @@ -218,7 +218,7 @@ def _find_common_parts(self) -> List[str]: Find common parts in all task names. """ common_parts = None - for task_name, _ in self.metric_handler.all_stats.items(): + for task_name in self.metric_handler.results.layers.keys(): parts = task_name.split('.') if common_parts is None: common_parts = set(parts) @@ -226,17 +226,8 @@ def _find_common_parts(self) -> List[str]: common_parts.intersection_update(parts) return list(common_parts) - - def _remove_common_parts(self, name: str) -> str: - """ - Remove common parts from the task name. - """ - parts = name.split('.') - cleaned_parts = [part for part in parts if part not in self.common_parts] - return '.'.join(cleaned_parts) - def _parse_task_names(self): - for task_name, _ in self.metric_handler.all_stats.items(): + for task_name in self.metric_handler.results.layers.keys(): self.hierarchy.append(task_name) def _add_nodes_and_edges(self, hierarchy): @@ -250,178 +241,109 @@ def _add_nodes_and_edges(self, hierarchy): prev = name def construct_graph(self): - self._add_nodes_and_edges(self.hierarchy) - - def plot_graph(self, colour_by='cossim_mean', save_to: str = None): - """ - Plot the graph using Plotly for interactivity. - """ - # Manually set positions for a straight line layout. - # Not yet implemented for more complex layouts with Parallel paths - pos = {node: (i, i/5) for i, node in enumerate(self.graph.nodes())} - - edge_x = [] - edge_y = [] - for edge in self.graph.edges(): - x0, y0 = pos[edge[0]] - x1, y1 = pos[edge[1]] - edge_x.extend([x0, x1, None]) - edge_y.extend([y0, y1, None]) - - edge_trace = go.Scatter( - x=edge_x, y=edge_y, - line=dict(width=1, color='#888'), - hoverinfo='none', - mode='lines') - - # Find all metrics that contain 'mean' in their keys - metrics_to_plot = [m for m in self.metric_handler.stat_names if 'mean' in m] - - node_x,node_y,node_text,hover_text = [], [], [], [] - node_values = {metric: [] for metric in metrics_to_plot} - - for node in self.graph.nodes(): - x, y = pos[node] - node_x.append(x) - node_y.append(y) - metric_values = self.metric_handler.stats_at_layer(node) - - # Build the text for each node - hover = self._remove_common_parts(node) - for metric in metrics_to_plot: - if metric in metric_values: - value = metric_values[metric] - hover += f"
{metric.replace('_', ' ').title()}: {value:.4f}{'%' if 'SMAPE' in metric else ''}" - node_values[metric].append(value) - - node_text.append(node) - hover_text.append(hover) - - node_colors = [value for value in node_values[colour_by]] - - node_trace = go.Scatter( - x=node_x, y=node_y, - mode='markers+text', - text=node_text, - textposition='top center', - hoverinfo='text', - hovertext=hover_text, - marker=dict( - showscale=True, - colorscale='Viridis', - color=node_colors, - cmin=min(node_values[colour_by]), - cmax=max(node_values[colour_by]), - size=10, - colorbar=dict( - thickness=15, - title=colour_by.replace('_', ' ').title(), - xanchor='left', - titleside='right', - ), - line_width=2)) - - fig = go.Figure(data=[edge_trace, node_trace], - layout=go.Layout( - showlegend=False, - hovermode='closest', - margin=dict(b=0, l=0, r=0, t=0), - xaxis=dict(showgrid=False, zeroline=False), - yaxis=dict(showgrid=False, zeroline=False))) - - if save_to: - fig.write_html(save_to) - return fig - + self._add_nodes_and_edges(self.hierarchy) def create_app(nn_graph): app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css']) app.layout = html.Div([ - html.Div([ - html.H1('Network Weights Similarity Visualisation', style={'textAlign': 'center', 'padding': '20px'}), - dcc.Dropdown( - id='line-plot-dropdown', - options=[{'label': metric.replace('_', ' ').title(), 'value': metric} for metric in nn_graph.metric_handler.stat_names if 'mean' in metric], - value='cossim_mean', - style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} - ), - dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}), - ], className='container-fluid'), - - html.Div([ - html.H3('Node Metrics', style={'textAlign': 'center'}), - dcc.Dropdown( - id='metric-dropdown', - options=[], - style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} - ), - dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}), - ], className='container-fluid') + create_header(), + create_line_plot_section(nn_graph), + create_node_metrics_section() ]) + register_callbacks(app, nn_graph) + + return app + +def create_header(): + return html.H1('Network Weights Similarity Visualization', + style={'textAlign': 'center', 'padding': '20px'}) + +def create_line_plot_section(nn_graph): + return html.Div([ + dcc.Dropdown( + id='line-plot-dropdown', + options=[{'label': metric.replace('_', ' ').title(), 'value': metric} + for metric in nn_graph.metric_handler.stat_names], + value='cossim', + style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} + ), + dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}) + ], className='container-fluid') + +def create_node_metrics_section(): + return html.Div([ + html.H3('Node Metrics', style={'textAlign': 'center'}), + dcc.Dropdown( + id='metric-dropdown', + options=[], + style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'}, + value=None + ), + dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}) + ], className='container-fluid') + + +def register_callbacks(app, nn_graph): @app.callback( - Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), - [Input('line-plot', 'clickData')] + Output('metric-dropdown', 'options'), + Output('metric-dropdown', 'value'), + Input('line-plot', 'clickData') ) def 
update_metric_dropdown_options(clickData): - if clickData is None: + if not clickData: return [], None - + try: node_name = clickData['points'][0]['x'] - options = list(nn_graph.metric_handler.stats_at_layer(node_name).keys()) - options = [option for option in options if 'std' not in option] - options = [ - {'label': option.replace('_', ' ').title(), 'value': option} for option in options if 'mean' not in option - ] - return options, options[0]['value'] if options else None + options = nn_graph.metric_handler.node_plot_options(node_name) + return options, options[0]['value'] if options else ([], None) - except (KeyError, IndexError, TypeError) as e: + except (KeyError, IndexError, AttributeError) as e: print(f"Error processing clickData: {e}") return [], None @app.callback( Output('node-details-plot', 'figure'), - [Input('line-plot', 'clickData'), Input('metric-dropdown', 'value')], + Input('metric-dropdown', 'value'), + State('line-plot', 'clickData') ) - def display_node_data(clickData, selected_metric): - if clickData is None: + def display_node_data(selected_metric, clickData): + if not clickData: return go.Figure() try: node_name = clickData['points'][0]['x'] - except (KeyError, IndexError, TypeError) as e: - print(f"Error processing clickData: {e}") - return go.Figure() + if not selected_metric: + selected_metric = nn_graph.metric_handler.node_plot_options(node_name)[0]['value'] - fig = go.Figure() - if 'histogram' in selected_metric or 'Histogram' in selected_metric: - trace = nn_graph.metric_handler.plot_node_hist(node_name, stat=selected_metric) - fig.add_trace(trace) - fig.update_layout( - title=f"Metrics for {node_name} | {selected_metric}", - xaxis_title="Metric", - yaxis_title="Value" - ) - elif 'heatmap' in selected_metric or 'Heatmap' in selected_metric: - trace = nn_graph.metric_handler.heatmap_plot(layer_name=node_name, stat=selected_metric) - fig.add_trace(trace) - fig.update_layout( - title=f"{node_name} | {selected_metric}", - xaxis_title="Model 1 Head", - yaxis_title="Model 0 Head" - ) + metric_name, plot_type = selected_metric - return fig + if 'histogram' in plot_type.lower(): + trace = nn_graph.metric_handler.plot_node_hist(node_name, stat=metric_name) + return create_figure(trace, f"Histogram for {node_name} | {metric_name}", "Metric", "Value") + elif 'heatmap' in plot_type.lower(): + trace = nn_graph.metric_handler.plot_node_heatmap(layer_name=node_name, stat=metric_name) + return create_figure(trace, f"Heatmap for {node_name} | {metric_name}", "Model 1 Head", "Model 0 Head") + + return go.Figure() + except (KeyError, IndexError, AttributeError) as e: + print(f"Error processing node data: {e}") + return go.Figure() @app.callback( Output('line-plot', 'figure'), - [Input('line-plot-dropdown', 'value')] + Input('line-plot-dropdown', 'value') ) def update_line_plot(selected_metric): - # fig = go.Figure() - # fig.add_trace(nn_graph.metric_handler.plotly_line_plot(selected_metric)) - # return fig - return nn_graph.metric_handler.plotly_line_plot(selected_metric) - return app + if not selected_metric: + return go.Figure() + return go.Figure(data=nn_graph.metric_handler.plotly_line_plot(selected_metric)) + +def create_figure(trace, title, xaxis_title, yaxis_title): + return go.Figure(data=[trace], layout=go.Layout( + title=title, + xaxis_title=xaxis_title, + yaxis_title=yaxis_title + )) From f36fd7063c4a145f560fd44a7ac5c17049b77a49 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 26 Jun 2024 16:48:17 +0100 Subject: [PATCH 28/64] restructure metrics storage, 
remove graph from plot, remove redundancies and refactor --- mergekit/metric_methods/all_metrics.py | 71 ------ mergekit/metric_methods/base.py | 174 ++++++++++++- mergekit/plot_tools/plot_tools.py | 323 +++++++++++++------------ 3 files changed, 327 insertions(+), 241 deletions(-) diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index f996b46e..fb17c26c 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -30,77 +30,6 @@ from dataclasses import dataclass, field from typing import Dict, List, Any - -# Results -# └── layers: Dict[str, Layer] -# └── Layer -# ├── name: str -# ├── metrics: Dict[str, Metric] -# │ └── Metric -# │ ├── name: str -# │ ├── histogram: Histogram (optional) -# │ │ ├── count: List[float] -# │ │ ├── edges: List[float] -# │ │ └── widths: List[float] -# │ ├── mean_std: MeanStd (optional) -# │ │ ├── mean: float -# │ │ └── std: float (optional) -# │ ├── heatmap: Heatmap (optional) -# │ │ └── data: torch.Tensor -# │ ├── value: float (optional) -# │ └── additional_data: Dict[str, Any] -# └── weight_info: WeightInfo - -@dataclass -class MeanStd: - mean: float - std: Optional[float] = None - -@dataclass -class Heatmap: - data: torch.Tensor - -@dataclass -class Histogram: - count: List[float] - edges: List[float] - widths: List[float] - -@dataclass -class Metric: - histogram: Histogram = None - mean_std: MeanStd = None - heatmap: Heatmap = None - value: float = None - additional_data: Dict[str, Any] = field(default_factory=dict) - - def filled_attributes(self) -> List[str]: - filled_attrs = [] - for attr, value in self.__dict__.items(): - if value is not None: - filled_attrs.append(attr) - return filled_attrs - -@dataclass -class Layer: - metrics: Dict[str, Metric] - weight_info: WeightInfo - - def metrics_with_property(self, prop: str) -> List[str]: - return [name for name, metric in self.metrics.items() if getattr(metric, prop) is not None] - -class Results: - # Class to store the statistics for each layer, redundant - remove or add more functionality - def __init__(self): - self.layers: Dict[str, Layer] = {} - - def add_layer(self, layer: Layer, name: str): - if name not in self.layers.keys(): - self.layers[name] = layer - - def get_metric(self, layer_name: str, metric_name: str) -> Metric: - return self.get_layer(layer_name, metric_name) - def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): """Validate tensor shapes and count.""" unique_shapes = set(t.shape for t in tensors) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index df42a39c..bb57bc97 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -13,17 +13,173 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see http://www.gnu.org/licenses/. 
-from abc import ABC, abstractmethod -from typing import Any, List, Optional - -from pydantic import BaseModel +from typing import Any, List, Optional, Dict from mergekit.architecture import WeightInfo -from mergekit.common import ImmutableMap, ModelReference -from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors +from mergekit.common import ModelReference -from mergekit.merge_methods.base import MergeMethod, ConfigParameterDef +from mergekit.merge_methods.base import MergeMethod +from dataclasses import dataclass, field +from collections import defaultdict +import torch class MetricMethod(MergeMethod): - pass \ No newline at end of file + pass + +# Structure of the results object + +# Results +# └── layers: Dict[str, Layer] +# └── Layer +# ├── weight_info: WeightInfo +# └── metrics: Dict[str, List[Metric]] +# └── Metric +# ├── histogram: Optional[Histogram] +# ├── mean_std: Optional[MeanStd] +# ├── heatmap: Optional[Heatmap] +# └── model_ref: Optional[ModelReference] + +# Each Layer stores metrics under a key (e.g. 'cosine_similarity') in a dictionary. +# The values stored under each key are a **list** of Metric objects. This is to allow for a single metric type to be computed for each model. + # For metrics which compare between models, (e.g. cosine similarity) the list will contain a single Metric object storing the comparison data. + # For metrics which analyse individual models, (e.g. intrinsic dimension) the list will contain a Metric object for each model. + + +@dataclass +class MeanStd: + mean: float + std: Optional[float] = None + +@dataclass +class Heatmap: + data: torch.Tensor + +@dataclass +class Histogram: + count: List[float] + edges: List[float] + widths: List[float] + +@dataclass +class Metric: + histogram: Optional[Histogram] = None + mean_std: Optional[MeanStd] = None + heatmap: Optional[Heatmap] = None + model_ref: Optional[ModelReference] = None # For intra-model metrics. + + def filled_attributes(self) -> List[str]: + filled_attrs = [] + for attr, value in self.__dict__.items(): + if value is not None: + filled_attrs.append(attr) + return filled_attrs + +@dataclass +class Layer: + weight_info: WeightInfo + metrics: Dict[str, List[Metric]] = field(default_factory=dict) + + def metrics_with_attribute(self, attribute: str) -> List[str]: + return [name for name, metric in self.metrics.items() if attribute in metric[0].filled_attributes()] + + def add_metric(self, metric: Metric, name: str): + if name not in self.metrics.keys(): + self.metrics[name] = [metric] + else: + self.metrics[name].append(metric) + + def add_metric_list(self, metric_list: List[Metric], name: str): + for metric in metric_list: + self.add_metric(metric, name) + +def expand_to_fit(all_layer_names: List[str], values: List[float], subset_layer_names: List[str]) -> List[float]: + """ + Expands a list of values to fit a larger list of layer names, filling in missing values with None. + + Args: + all_layer_names (List[str]): List of all layer names. + values (List[float]): List of values to expand. + subset_layer_names (List[str]): List of layer names that the values correspond to. + + Returns: + List[float]: Expanded list of values, with None values for missing layers. 
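    Example (illustrative, not from the patch):
        expand_to_fit(['layer0', 'layer1', 'layer2'], [0.5, 0.7], ['layer0', 'layer2'])
        returns [0.5, None, 0.7]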
+ """ + result = [None] * len(all_layer_names) + subset_dict = dict(zip(subset_layer_names, values)) + + for i, layer in enumerate(all_layer_names): + if layer in subset_dict: + result[i] = subset_dict[layer] + + return result + +class Results: + # Class to store the statistics for each layer + def __init__(self): + self.layers: Dict[str, Layer] = {} + + def add_layer(self, layer: Layer, name: str): + if name not in self.layers.keys(): + self.layers[name] = layer + + def get_metric(self, layer_name: str, metric_name: str) -> Metric: + return self.get_layer(layer_name, metric_name) + + def get_lineplot_data(self, metric_name: str): + means, stds = defaultdict(list), defaultdict(list) + layers = [] + + for name, layer in self.layers.items(): + if metric_name in layer.metrics: + for model_result in layer.metrics[metric_name]: + model_ref = model_result.model_ref if model_result.model_ref else 'all' + means[model_ref].append(model_result.mean_std.mean) + stds[model_ref].append(model_result.mean_std.std) + layers.append(name) + + means_list, stds_list, model_references = list(means.values()), list(stds.values()), list(means.keys()) + for i, model_ref in enumerate(model_references): + means_list[i] = expand_to_fit(all_layer_names=list(self.layers.keys()), values=means_list[i], subset_layer_names=layers) + stds_list[i] = expand_to_fit(all_layer_names=list(self.layers.keys()), values=stds_list[i], subset_layer_names=layers) + + return means_list, stds_list, model_references + + def available_metrics(self) -> Dict[str, Dict[str, Any]]: + all_metrics = set() + for layer in self.layers.values(): + all_metrics.update(layer.metrics.keys()) + + metric_info = {} + for metric in all_metrics: + info = { + 'layers': [], + 'has_mean_std': False, + 'has_histogram': False, + 'has_heatmap': False, + 'has_model_ref': False + } + for layer_name, layer in self.layers.items(): + if metric in layer.metrics: + info['layers'].append(layer_name) + for m in layer.metrics[metric]: + if m.mean_std: + info['has_mean_std'] = True + if m.histogram: + info['has_histogram'] = True + if m.heatmap: + info['has_heatmap'] = True + if m.model_ref: + info['has_model_ref'] = True + metric_info[metric] = info + return metric_info + + def print_metric_summary(self): + metric_info = self.available_metrics() + print("Available Metrics Summary:") + for metric, info in metric_info.items(): + print(f"\nMetric: {metric}") + # print(f" Available in layers: {', '.join(info['layers'])}") + print(f" Has mean/std: {'Yes' if info['has_mean_std'] else 'No'}") + print(f" Has histogram: {'Yes' if info['has_histogram'] else 'No'}") + print(f" Has heatmap: {'Yes' if info['has_heatmap'] else 'No'}") + print(f" Has model reference: {'Yes' if info['has_model_ref'] else 'No'}") \ No newline at end of file diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 9304dd1a..047b5fdc 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -8,11 +8,12 @@ import dash from dash import dcc, html from dash.dependencies import Input, Output, State -from mergekit.metric_methods.all_metrics import Layer, Results +from mergekit.metric_methods.all_metrics import Layer +from mergekit.metric_methods.base import Results -class MetricsHandler: +class ResultsHandler: """ - Object to handle metrics output. Allows for easy plotting of metrics by layer and across layers. + Object to handle metrics results. Allows for easy plotting of metrics by layer and across layers. 
Input: Use the load_metrics method to load the metrics into the handler. @@ -20,52 +21,29 @@ class MetricsHandler: Attributes: all_stats: Dictionary of recorded statistics for each layer. e.g. {'layer_name': {'cossim_mean': 0.5, 'cossim_std': 0.1}} - stat_names: List of names of all statistics available. e.g. ['cossim_mean', 'cossim_std'] + metric_names: List of names of all statistics available. e.g. ['cossim_mean', 'cossim_std'] layer_names: List of layer names. Methods: load_metrics: Load the metrics into the handler. - stats_at_layer: Get the metrics for a specific layer. - info_at_layer: Get the weight info for a specific layer. + # stats_at_layer: Get the metrics for a specific layer. + # info_at_layer: Get the weight info for a specific layer. line_plot: Plot a line plot of the chosen stat across layers. - plot_node_hist: Plot a histogram of the stat for a specific layer. + plotly_layer_histogram: Plot a histogram of the stat for a specific layer. """ - def __init__(self): - self.layer_names: List[str] = [] + def __init__(self, metrics: List[Tuple[Task, Layer]]): self.results = Results() - self.stat_names: List[str] = [] + self.load_metrics(metrics) def load_metrics(self, metrics: List[Tuple[Task, Layer]]): + self.metric_names = [] for task, metric in metrics: if metric is not None: self.results.add_layer(metric, name=task.weight_info.name) - self.stat_names.extend(list(metric.metrics.keys())) + self.metric_names.extend(list(metric.metrics.keys())) self.layer_names = list(self.results.layers.keys()) - self.stat_names = list(set(self.stat_names)) - - def stats_at_layer(self, layer_name: str) -> Dict[str, Any]: - if layer_name not in self.results.layers: - raise ValueError(f"Layer {layer_name} not found") - return self.results.layers[layer_name] - - def info_at_layer(self, layer_name: str): - if layer_name not in self.results.layers: - raise ValueError(f"Layer {layer_name} not found") - return self.results.layers[layer_name].weight_info - - def line_plot(self, stat: str, save_to:Optional[str]=None, **kwargs): - fig, ax = plt.subplots() + self.metric_names = list(set(self.metric_names)) - ax_kwargs = ['ylabel', 'title', 'ylim', 'xticklabels'] - plot_kwargs = {k: v for k, v in kwargs.items() if k not in ax_kwargs} - - self._line_plot(ax, stat, plot_kwargs) - self._set_plot_attributes(ax, stat, ax_kwargs, **kwargs) - if save_to: - plt.savefig(save_to) - plt.show() - plt.close() - def categorise_layers(self, layer_names): # Hardcoded layernames for now - can be extended to include more categories or further generalised based on config categories = [] @@ -74,11 +52,34 @@ def categorise_layers(self, layer_names): categories.append('Attention Block') elif 'mlp' in name: categories.append('MLP') + elif 'layernorm' in name: + categories.append('LayerNorm') else: categories.append('Other') return categories + + def plotly_line_plots(self, metric_name:str): + if metric_name not in self.metric_names: + print(f"Stat {metric_name} not found") + return [] + + layer_names = self.layer_names + means, stds, model_refs = self.results.get_lineplot_data(metric_name) + traces = [] + available_shapes = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] + + if len(model_refs) > 1: + unique_categories = [str(ref) for ref in model_refs] + layer_categories = [[str(model_refs[i])]*len(layer_names) for i in range(len(model_refs))] + else: + layer_categories = [self.categorise_layers(layer_names)] + unique_categories = list(set(layer_categories[0])) + for 
i, model_ref in enumerate(model_refs): + traces.extend(self._plotly_line_plot(layer_names, means[i], stds[i], layer_categories[i], unique_categories, shape=available_shapes[i%len(available_shapes)])) + + return traces, layer_names - def plotly_line_plot(self, stat: str, **kwargs): + def _plotly_line_plot(self, x_values, means, stds, layer_categories, unique_categories, shape:str='circle', **kwargs): """ Plot the stat values across layers using Plotly. @@ -89,16 +90,6 @@ def plotly_line_plot(self, stat: str, **kwargs): List[go.Scatter]: List of Plotly Scatter objects. """ - if stat not in self.stat_names: - print(f"Stat {stat} not found") - return - - means = [self.results.layers[layer].metrics[stat].mean_std.mean for layer in self.layer_names] - stds = [self.results.layers[layer].metrics[stat].mean_std.std for layer in self.layer_names] - - layer_categories = self.categorise_layers(self.layer_names) - unique_categories = list(set(layer_categories)) - # Assign a unique color to each category cmap = plt.get_cmap('Set1', len(unique_categories)) colors = [mcolors.to_hex(cmap(i)) for i in range(len(unique_categories))] @@ -110,9 +101,11 @@ def plotly_line_plot(self, stat: str, **kwargs): for category in unique_categories: y_category = [means[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] std_category = [stds[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + if all([y is None for y in y_category]): + continue traces.append(go.Scatter( - x=self.layer_names, + x=x_values, y=y_category, error_y=dict( type='data', @@ -121,39 +114,31 @@ def plotly_line_plot(self, stat: str, **kwargs): ), mode='markers', name=category, - marker=dict(color=category_styles[category]) + marker=dict(color=category_styles[category]), + marker_symbol=shape )) return traces - def _line_plot(self, ax, stat:str, plot_kwargs: Optional[Dict[str, Any]] = {}): - """ - Plot the stat values with optional error bars. - - Args: - ax: The matplotlib Axes object. - stat (str): The name of the stat to plot. - plot_kwargs: Additional keyword arguments for plotting. - """ - means = [self.results.layers[layer].metrics[stat].mean for layer in self.layer_names] - stds = [self.results.layers[layer].metrics[stat].std for layer in self.layer_names] - ax.errorbar(self.layer_names, means, yerr=stds, fmt='-o', **plot_kwargs) - - def plot_node_heatmap(self, layer_name:str, stat:str): + def plotly_layer_heatmap(self, layer_name:str, metric_name:str): """ Plot the stat values as a heatmap. Args: layer_name (str): The name of the layer. - stat (str): The name of the stat to plot. + metric_name (str): The name of the stat to plot. Returns: go.Heatmap: Plotly Heatmap object. """ - heatmap = self.results.layers[layer_name].metrics[stat].heatmap.data + metrics_list = self.results.layers[layer_name].metrics[metric_name] + if len(metrics_list) > 1: + raise Warning(f"Multiple heatmaps found for {metric_name} at layer {layer_name}. 
Using the first one.") + + heatmap = self.results.layers[layer_name].metrics[metric_name][0].heatmap.data - return go.Heatmap( + return [go.Heatmap( z=heatmap, colorscale='RdBu' - ) + )] def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): """ @@ -175,84 +160,51 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): if kwarg in kwargs: getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) - def plot_node_hist(self, layer_name: str, stat: str): - - hist = self.stats_at_layer(layer_name).metrics[stat].histogram - count, edges, widths = hist.count, hist.edges, hist.widths - return go.Bar( - x=edges[:-1], - y=count, - width=widths, - marker=dict( - color='blue', - line=dict( - color='black', - width=1 - ) - ) - ) + def plotly_layer_histogram(self, layer_name: str, metric_name: str): + metric_list = self.results.layers[layer_name].metrics[metric_name] + colors = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] # (X) + + traces = [] + for i, metric in enumerate(metric_list): + hist = metric.histogram + count, edges, widths = hist.count, hist.edges, hist.widths + traces.append(go.Bar( + x=edges[:-1], + y=count, + width=widths, + marker=dict( + color=colors[i], + opacity=0.75, + line=dict( + color='black', + width=1 + ) + ), + name=str(metric.model_ref) + )) + return traces - def node_plot_options(self, node): - layer = self.results.layers[node] + def layer_plot_options(self, layer_name: str): + layer = self.results.layers[layer_name] return [ {"label": f"{metric.title()} Histogram", "value": [metric, 'histogram']} - for metric in layer.metrics_with_property('histogram') + for metric in layer.metrics_with_attribute('histogram') ] + [ {"label": f"{metric.title()} Heatmap", "value": [metric, 'heatmap']} - for metric in layer.metrics_with_property('heatmap') + for metric in layer.metrics_with_attribute('heatmap') ] - -class ModelGraph: - def __init__(self, metrics: List[Tuple['Task', Dict[str, Any]]]): - self.metric_handler = MetricsHandler() - self.metric_handler.load_metrics(metrics) - self.hierarchy = [] - self.common_parts = self._find_common_parts() - self.graph = nx.DiGraph() - self._parse_task_names() - - def _find_common_parts(self) -> List[str]: - """ - Find common parts in all task names. 
- """ - common_parts = None - for task_name in self.metric_handler.results.layers.keys(): - parts = task_name.split('.') - if common_parts is None: - common_parts = set(parts) - else: - common_parts.intersection_update(parts) - - return list(common_parts) - def _parse_task_names(self): - for task_name in self.metric_handler.results.layers.keys(): - self.hierarchy.append(task_name) - - def _add_nodes_and_edges(self, hierarchy): - # Current implementation builds linear graph - # Parallel paths (heads, skips) not yet supported - prev = None - for name in hierarchy: - self.graph.add_node(name) - if prev: - self.graph.add_edge(prev, name) - prev = name - - def construct_graph(self): - self._add_nodes_and_edges(self.hierarchy) - -def create_app(nn_graph): +def create_app(results_handler): app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css']) app.layout = html.Div([ create_header(), - create_line_plot_section(nn_graph), - create_node_metrics_section() + create_line_plot_section(results_handler), + create_layer_metrics_section() ]) - register_callbacks(app, nn_graph) + register_callbacks(app, results_handler) return app @@ -260,76 +212,104 @@ def create_header(): return html.H1('Network Weights Similarity Visualization', style={'textAlign': 'center', 'padding': '20px'}) -def create_line_plot_section(nn_graph): +def create_line_plot_section(results_handler): return html.Div([ dcc.Dropdown( id='line-plot-dropdown', - options=[{'label': metric.replace('_', ' ').title(), 'value': metric} - for metric in nn_graph.metric_handler.stat_names], + options=[{'label': metric_name.replace('_', ' ').title(), 'value': metric_name} + for metric_name in results_handler.metric_names], value='cossim', style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} ), dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}) ], className='container-fluid') -def create_node_metrics_section(): +def create_layer_metrics_section(): return html.Div([ - html.H3('Node Metrics', style={'textAlign': 'center'}), + html.H3('Layer Metrics', style={'textAlign': 'center'}), dcc.Dropdown( id='metric-dropdown', options=[], style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'}, value=None ), - dcc.Graph(id='node-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}) + dcc.Graph(id='layer-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}) ], className='container-fluid') - -def register_callbacks(app, nn_graph): +def default_option(options, current_value): + if not options: + return None + if current_value.lower() in [o.lower() for o in options]: + return current_value + for option in options: + if option.lower() in current_value.lower() or current_value.lower() in option.lower(): + return option + return options[0] + +def register_callbacks(app, results_handler): @app.callback( Output('metric-dropdown', 'options'), Output('metric-dropdown', 'value'), - Input('line-plot', 'clickData') + Input('line-plot', 'clickData'), + Input('line-plot-dropdown', 'value') ) - def update_metric_dropdown_options(clickData): + def update_metric_dropdown_options(clickData, selected_metric): if not clickData: return [], None try: - node_name = clickData['points'][0]['x'] - options = nn_graph.metric_handler.node_plot_options(node_name) - return options, options[0]['value'] if options else ([], None) + layer_name = clickData['points'][0]['x'] + options = 
results_handler.layer_plot_options(layer_name) + default_label = default_option(list(map(lambda x: x['label'], options)), selected_metric) + default_value = [option['value'] for option in options if option['label'] == default_label][0] + return options, default_value except (KeyError, IndexError, AttributeError) as e: print(f"Error processing clickData: {e}") return [], None @app.callback( - Output('node-details-plot', 'figure'), + Output('layer-details-plot', 'figure'), Input('metric-dropdown', 'value'), State('line-plot', 'clickData') ) - def display_node_data(selected_metric, clickData): + def display_layer_data(selected_metric, clickData): if not clickData: return go.Figure() try: - node_name = clickData['points'][0]['x'] + layer_name = clickData['points'][0]['x'] if not selected_metric: - selected_metric = nn_graph.metric_handler.node_plot_options(node_name)[0]['value'] + selected_metric = results_handler.layer_plot_options(layer_name)[0]['value'] metric_name, plot_type = selected_metric - if 'histogram' in plot_type.lower(): - trace = nn_graph.metric_handler.plot_node_hist(node_name, stat=metric_name) - return create_figure(trace, f"Histogram for {node_name} | {metric_name}", "Metric", "Value") - elif 'heatmap' in plot_type.lower(): - trace = nn_graph.metric_handler.plot_node_heatmap(layer_name=node_name, stat=metric_name) - return create_figure(trace, f"Heatmap for {node_name} | {metric_name}", "Model 1 Head", "Model 0 Head") + # Define default axis titles + xaxis_title = "Value" + yaxis_title = "Count" + + # Update axis titles if plot_type is 'heatmap' + if plot_type.lower() == "heatmap": + xaxis_title = "Model 1 Head" + yaxis_title = "Model 0 Head" - return go.Figure() + plot_function = { + 'histogram': results_handler.plotly_layer_histogram, + 'heatmap': results_handler.plotly_layer_heatmap + }.get(plot_type.lower(), + lambda *args, **kwargs: go.Figure()) # Defaults to *function* to produce empty figure + + traces = plot_function(layer_name=layer_name, + metric_name=metric_name) + + return create_figure(traces=traces, + title=f"{plot_type.title()} for {layer_name} | {metric_name}", + xaxis_title=xaxis_title, + yaxis_title=yaxis_title + ) + except (KeyError, IndexError, AttributeError) as e: - print(f"Error processing node data: {e}") + print(f"Error processing layer data: {e}") return go.Figure() @app.callback( @@ -339,11 +319,32 @@ def display_node_data(selected_metric, clickData): def update_line_plot(selected_metric): if not selected_metric: return go.Figure() - return go.Figure(data=nn_graph.metric_handler.plotly_line_plot(selected_metric)) + + traces, layer_names = results_handler.plotly_line_plots(metric_name=selected_metric) + fig = go.Figure() + for trace in traces: + fig.add_trace(trace) + + fig.update_layout( + title=f"{selected_metric.replace('_', ' ').title()} Across Layers", + xaxis=dict( + title='Layer', + tickvals=list(range(len(layer_names))), + ticktext=layer_names + ), + yaxis=dict(title=selected_metric.replace('_', ' ').title()) + ) + return fig -def create_figure(trace, title, xaxis_title, yaxis_title): - return go.Figure(data=[trace], layout=go.Layout( +def create_figure(traces, title, xaxis_title, yaxis_title): + fig = go.Figure() + for trace in traces: + fig.add_trace(trace) + + fig.update_layout( title=title, - xaxis_title=xaxis_title, - yaxis_title=yaxis_title - )) + xaxis=dict(title=xaxis_title), + yaxis=dict(title=yaxis_title) + ) + + return fig From 50a57167bcf35e16e26557d5dbcdcbd6bc943209 Mon Sep 17 00:00:00 2001 From: "es3e20@soton.ac.uk" Date: Wed, 
26 Jun 2024 16:51:00 +0100 Subject: [PATCH 29/64] Add intra-layer metrics, completed implementation of changes from prev commit --- examples/metrics-llama-1v2.yml | 2 + examples/metrics-small.yml | 2 + mergekit/metric_methods/all_metrics.py | 208 ++++++++++++++++++++----- mergekit/scripts/run_metrics.py | 9 +- 4 files changed, 179 insertions(+), 42 deletions(-) diff --git a/examples/metrics-llama-1v2.yml b/examples/metrics-llama-1v2.yml index 58641b85..9cca1ae8 100644 --- a/examples/metrics-llama-1v2.yml +++ b/examples/metrics-llama-1v2.yml @@ -3,4 +3,6 @@ models: - model: TheBloke/Llama-2-7B-fp16 metric_method: all +parameters: + intra_model_metrics: true dtype: float32 diff --git a/examples/metrics-small.yml b/examples/metrics-small.yml index 3f13c4d4..6c6ec1a9 100644 --- a/examples/metrics-small.yml +++ b/examples/metrics-small.yml @@ -3,4 +3,6 @@ models: - model: BEE-spoke-data/smol_llama-220M-openhermes metric_method: all +parameters: + intra_model_metrics: true dtype: float32 diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index fb17c26c..80b2e123 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -13,21 +13,16 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from mergekit.architecture import WeightInfo from mergekit.common import ModelReference from mergekit.graph import Task from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod - +from mergekit.metric_methods.base import MetricMethod, MeanStd, Heatmap, Histogram, Metric, Layer import torch import torch.nn.functional as F - import numpy as np - - -from dataclasses import dataclass, field from typing import Dict, List, Any def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): @@ -105,7 +100,7 @@ def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: return similarity_matrix -# Metric functions +# Tensor Comparisons (Require exactly 2 tensors) def smape( tensors: List[torch.Tensor], **_kwargs @@ -155,8 +150,6 @@ def scale( norm_0 = torch.norm(tensors[0], dim=1) norm_1 = torch.norm(tensors[1], dim=1) - res = {} - scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) if return_heatmap: @@ -204,12 +197,78 @@ def mse( # Tensor Analysis (number of tensors can vary) +def weight_magnitude(tensors: List[torch.Tensor], model_refs: List[ModelReference]) -> List[Metric]: + output = [] + for tensor, model_reference in zip(tensors, model_refs): + weight_magnitudes = torch.abs(tensor.flatten()) + hist_info = compute_histogram(weight_magnitudes, 100) + output.append(Metric( + histogram=Histogram(count=hist_info[0], + edges=hist_info[1], + widths=hist_info[2] + ), + mean_std=MeanStd(mean=weight_magnitudes.mean().item(), + std=weight_magnitudes.std().item()), + model_ref=model_reference + )) + return output + +def numerical_rank(tensors: List[torch.Tensor], model_refs: List[ModelReference], epsilon: float = 1e-5) -> List[Metric]: + """ + Computes the numerical rank of the representations matrix X based on the singular values + of its sample covariance matrix. The rank is determined as the number of singular values + above a threshold. The threshold is defined as the highest singular value times a given epsilon. 
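    Illustrative example (hypothetical numbers, not from the patch): with covariance singular values
    [4.0, 2.0, 1e-4, 1e-7] and epsilon = 1e-5, the threshold is 4.0 * 1e-5 = 4e-5, so three singular
    values exceed it and the numerical rank is 3.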
+ + Parameters: + - X : torch.Tensor + The representations matrix from which the sample covariance matrix will be computed. + - epsilon : float, optional + The factor to multiply with the highest singular value to set the threshold (default is 1e-3). + - flip : bool, optional - allows transpose for efficient computation. False only used in testing + Returns: + - int + The numerical rank of the matrix. + + Implemented according to description in the paper: + The Tunnel Effect: Building Data Representations in Deep Neural Networks + https://arxiv.org/pdf/2305.19753.pdf + + """ + output = [] + for tensor, model_reference in zip(tensors, model_refs): + + # Center the data by subtracting the mean + X_centered = tensor - torch.mean(tensor, dim=0) + X_std = torch.std(X_centered, dim=0, unbiased=False) + X_centered /= X_std + + # Compute the sample covariance matrix + covariance_matrix = X_centered.t() @ X_centered / (tensor.shape[0] - 1) + # Compute singular values using SVD on the covariance matrix + U, singular_values, V = torch.svd(covariance_matrix) + # Determine the threshold + threshold = singular_values[0] * epsilon + # Count singular values greater than the threshold + num_rank = torch.sum(singular_values > threshold).item() + + value = int(num_rank) + + output.append( + Metric( + model_ref=model_reference, + mean_std=MeanStd( + mean=value, + std=None), + )) + + return output # Tasks class MLPTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo + intra_model_metrics: bool = False def uses_accelerator(self) -> bool: return True @@ -222,15 +281,20 @@ def execute( ) -> torch.Tensor: weights = list(tensors.values()) validate_tensors(weights, self.weight_info, expected_tensors=2) - out = Layer(metrics={}, + layer_results = Layer(metrics={}, weight_info=self.weight_info) - out.metrics['cossim'] = cossim(weights, return_heatmap=False) - out.metrics['smape'] = smape(weights) - out.metrics['scale'] = scale(weights, return_heatmap=False) - out.metrics['mse'] = mse(weights, return_heatmap=False) # Highly inefficient + layer_results.add_metric(cossim(weights, return_heatmap=False), name = 'cossim') + layer_results.add_metric(smape(weights), name = 'smape') + layer_results.add_metric(scale(weights, return_heatmap=False), name = 'scale') + layer_results.add_metric(mse(weights, return_heatmap=False), name = 'mse') - return out + if self.intra_model_metrics: + model_refs = list(tensors.keys()) + layer_results.add_metric_list(metric_list=weight_magnitude(weights, model_refs), name='weight_magnitude') + layer_results.add_metric_list(metric_list=numerical_rank(weights, model_refs), name='numerical_rank') + + return layer_results def group_label(self) -> Optional[str]: return self.gather_tensors.group_label() @@ -239,6 +303,7 @@ class AttnTask(Task[torch.Tensor]): weights: Dict[str, GatherTensors] weight_infos: Dict[str, WeightInfo] weight_info: WeightInfo + intra_model_metrics: bool = False def uses_accelerator(self) -> bool: return True @@ -252,10 +317,18 @@ def execute( ) -> torch.Tensor: # Add metrics for attention weights - models = list(q_proj.keys()) - - k_proj_0, v_proj_0, q_proj_0, o_proj_0 = group_attn_head_weights(k_proj[models[0]], q_proj[models[0]], v_proj[models[0]], o_proj[models[0]], self.weight_info) - k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(k_proj[models[1]], q_proj[models[1]], v_proj[models[1]], o_proj[models[1]], self.weight_info) + model_references = list(q_proj.keys()) + + k_proj_0, v_proj_0, q_proj_0, o_proj_0 = 
group_attn_head_weights(k_proj[model_references[0]], + q_proj[model_references[0]], + v_proj[model_references[0]], + o_proj[model_references[0]], + self.weight_info) + k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(k_proj[model_references[1]], + q_proj[model_references[1]], + v_proj[model_references[1]], + o_proj[model_references[1]], + self.weight_info) # Metrics for K, V, Q, O projections @@ -265,22 +338,35 @@ def execute( model_0_heads = torch.cat([k_proj_0, v_proj_0, q_proj_0, o_proj_0], dim=1) model_1_heads = torch.cat([k_proj_1, v_proj_1, q_proj_1, o_proj_1], dim=1) - out = Layer(metrics={}, + layer_results = Layer(metrics={}, weight_info=self.weight_info) - out.metrics['cossim'] = cossim([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True) - out.metrics['smape'] = smape([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)]) - out.metrics['scale'] = scale([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True) - out.metrics['mse'] = mse([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=False) - return out + layer_results.add_metric(cossim([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True), + name = 'cossim') + layer_results.add_metric(smape([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)]), + name = 'smape') + layer_results.add_metric(scale([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True), + name = 'scale') + layer_results.add_metric(mse([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=False), + name = 'mse') + + if self.intra_model_metrics: + + layer_results.add_metric_list( + metric_list=weight_magnitude([model_0_heads, model_1_heads], model_refs=model_references), + name='weight_magnitude' + ) + + + return layer_results def group_label(self) -> Optional[str]: return max([gather_tensor.group_label() for gather_tensor in list(self.weights.values())]) @@ -293,6 +379,36 @@ def __eq__(self, other): return False return self.weight_info == other.weight_info +class LayerNormTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + weight_info: WeightInfo + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + + tensors = list(tensors.values()) + + assert tensors[0].dim() == 1, "LayerNorm tensors must be 2D" + assert tensors[1].dim() == 1, "LayerNorm tensors must be 2D" + + layer_results = Layer(metrics={}, weight_info=self.weight_info) + + layer_results.add_metric(cossim([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'cossim') + layer_results.add_metric(smape([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)]), name = 'smape') + layer_results.add_metric(scale([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'scale') + layer_results.add_metric(mse([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'mse') + + return layer_results + + def group_label(self) -> Optional[str]: + return 
self.gather_tensors.group_label() + class DummyTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo @@ -311,14 +427,16 @@ def execute( def group_label(self) -> Optional[str]: return self.gather_tensors.group_label() + +from mergekit.merge_methods.base import ConfigParameterDef + # Metric method - class AllMetric(MetricMethod): attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {} attn_info_dict: Optional[Dict[str, WeightInfo]] = {} - attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] # hard-coded for now block_count: Optional[int] = 0 + def make_task( self, *, @@ -329,16 +447,26 @@ def make_task( ) -> Task: if 'self_attn' in output_weight.name: - return self.group_attn_heads(tensors, output_weight) + return self.group_attn_heads(tensors, output_weight, parameters) elif 'mlp' in output_weight.name: return MLPTask( gather_tensors=tensors, weight_info=output_weight, intra_model_metrics=parameters['intra_model_metrics'] ) + elif 'layernorm' in output_weight.name: + return LayerNormTask(gather_tensors=tensors, weight_info=output_weight) else: # Executor expects a task to be returned return DummyTask(gather_tensors=tensors, weight_info=output_weight) + + def group_attn_heads(self, tensors: GatherTensors, output_weight: WeightInfo, parameters: Dict[str, Any]): + # collect all attention weights + for part in self.attn_parts: # also check only one key + if part in output_weight.name: + assert self.attn_weight_dict.get(part) is None, f"Duplicate attention part {part}" + self.attn_weight_dict[part] = tensors + self.attn_info_dict[part] = output_weight # if all attention weights are collected, create attention task if set(list(self.attn_weight_dict.keys())) == set(self.attn_parts): @@ -353,9 +481,15 @@ def make_task( num_attention_heads=int(infos['k_proj'].num_attention_heads) ) self.block_count += 1 - return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) + return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info, intra_model_metrics=parameters['intra_model_metrics']) else: # Executor expects a task to be returned return DummyTask(gather_tensors=tensors, weight_info=output_weight) + + + def parameters(self) -> List[ConfigParameterDef]: + return [ + ConfigParameterDef(name="intra_model_metrics", required=False, default_value=False), + ] diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py index 1090e63e..d5452808 100644 --- a/mergekit/scripts/run_metrics.py +++ b/mergekit/scripts/run_metrics.py @@ -5,7 +5,7 @@ from mergekit.config import MergeConfiguration from mergekit.merge import MergeOptions from mergekit.merge import run_merge -from mergekit.plot_tools.plot_tools import ModelGraph, create_app +from mergekit.plot_tools.plot_tools import create_app, ResultsHandler @click.command() @click.option('--output_path', default="./merged", help='folder to store the result in.') @@ -17,7 +17,7 @@ def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory) with open(config_yml, "r", encoding="utf-8") as fp: metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) - metrics = run_merge( + metrics_results = run_merge( metric_config, out_path=output_path, options=MergeOptions( @@ -28,10 +28,9 @@ def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory) ), ) - nn_graph = ModelGraph(metrics) - nn_graph.construct_graph() + handler = ResultsHandler(metrics_results) - app = create_app(nn_graph=nn_graph) + app = 
create_app(results_handler=handler) app.run_server() if __name__ == '__main__': From 73dd3ae0088504dd364047a22c5c430ec3b35f6a Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 4 Jul 2024 15:21:50 +0100 Subject: [PATCH 30/64] restructure metrics for modularity --- mergekit/metric_methods/all_metrics.py | 185 +--------------------- mergekit/metric_methods/metrics.py | 202 +++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 182 deletions(-) create mode 100644 mergekit/metric_methods/metrics.py diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index 80b2e123..453cb39a 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -19,11 +19,10 @@ from mergekit.common import ModelReference from mergekit.graph import Task from mergekit.io.tasks import GatherTensors -from mergekit.metric_methods.base import MetricMethod, MeanStd, Heatmap, Histogram, Metric, Layer +from mergekit.metric_methods.base import MetricMethod, Layer import torch -import torch.nn.functional as F -import numpy as np from typing import Dict, List, Any +from mergekit.metric_methods.metrics import cossim, smape, scale, mse, weight_magnitude, numerical_rank def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): """Validate tensor shapes and count.""" @@ -45,7 +44,7 @@ def ungroup_tensor(input_tensor: torch.Tensor, gqa_groups: int) -> torch.Tensor: for i in range(gqa_groups): ungrouped_tensor[i*rows:(i+1)*rows] = input_tensor[i].expand(rows, -1) - return ungrouped_tensor + return ungrouped_tensor.to(input_tensor.device) def restructure_tensor(input_tensor: torch.Tensor, num_rows: int) -> torch.Tensor: """ @@ -85,184 +84,6 @@ def group_attn_head_weights(k_proj: torch.Tensor, return k_proj, v_proj, q_proj, o_proj -def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: - bin_counts, bin_edges = np.histogram(tensor.numpy(), bins=n_bins) - bin_widths = np.diff(bin_edges) - return bin_counts, bin_edges, bin_widths - -def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: - # Normalize the rows of both matrices - A_norm = A / A.norm(dim=1, keepdim=True) - B_norm = B / B.norm(dim=1, keepdim=True) - - # Compute the cosine similarity matrix - similarity_matrix = torch.mm(A_norm, B_norm.t()) - - return similarity_matrix - -# Tensor Comparisons (Require exactly 2 tensors) - -def smape( - tensors: List[torch.Tensor], **_kwargs -) -> Metric: - """Symmetric Mean Absolute Percentage Error (smape).""" - - numerator = torch.abs(tensors[0] - tensors[1]) - denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) - smape = torch.mean(torch.div(numerator, denominator), dim=1) - - hist_info = compute_histogram(smape, 100) - - return Metric( - histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), - mean_std=MeanStd(mean=smape.mean().item(), std=smape.std().item()) - ) - -def cossim( - tensors: List[torch.Tensor], return_heatmap=False, **_kwargs -) -> Metric: - """Cosine similarity""" - cossim = F.cosine_similarity(tensors[0], tensors[1], dim=1) - - if return_heatmap: - heatmap = cossim_heatmap(tensors[0], tensors[1]) - - assert torch.isclose(cossim, cossim, atol=1e-6).all(), "NaNs in cosine similarity" - assert torch.isclose(cossim, cossim_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-2).all(), "Diagonal elements of cosine similarity matrix do not match" - - hist_info = compute_histogram(cossim, 100) - return Metric( - 
histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), - mean_std=MeanStd(mean=cossim.mean().item(), std=cossim.std().item()), - heatmap=Heatmap(data=heatmap) if return_heatmap else None - ) - -def scale( - tensors: List[torch.Tensor], return_heatmap=False, **_kwargs -) -> Metric: - """ - Scale difference: ratio of absolute difference to average scale. - Complementary to cosine similarity, which measures the angle between two vectors and is invariant to scale. - - values close to 0 indicate that the scales of the two vectors are similar - """ - - norm_0 = torch.norm(tensors[0], dim=1) - norm_1 = torch.norm(tensors[1], dim=1) - - scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) - - if return_heatmap: - norm_0 = norm_0.unsqueeze(1) # shape becomes [num_heads, 1] - norm_1 = norm_1.unsqueeze(0) # shape becomes [1, num_heads] - - # Compute the scale difference between each pair of heads by broadcasting - heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1 + 1e-10) / 2) - - assert torch.isclose(scale_diff, heatmap.diagonal(), atol=1e-4).all(), "Diagonal elements of scale difference matrix do not match" - - hist_info = compute_histogram(scale_diff, 100) - - return Metric( - histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), - mean_std=MeanStd(mean=scale_diff.mean().item(), std=scale_diff.std().item()), - heatmap=Heatmap(data=heatmap) if return_heatmap else None - ) - -def mse( - tensors: List[torch.Tensor], return_heatmap: bool =False, **_kwargs -) -> Metric: - """Mean squared error (MSE).""" - if return_heatmap: - # Expand dimensions for broadcasting - tensors_0_exp = tensors[0].unsqueeze(1) # shape becomes [num_heads, 1, ...] - tensors_1_exp = tensors[1].unsqueeze(0) # shape becomes [1, num_heads, ...] - - # Compute squared differences - diffs = (tensors_0_exp - tensors_1_exp) ** 2 - - # Compute mean over all dimensions except the first two - heatmap = diffs.mean(dim=tuple(range(2, diffs.dim()))).numpy() - - squared_diff = (tensors[0] - tensors[1]) ** 2 - mse_per_neuron = torch.mean(squared_diff, dim=1) - - hist_info = compute_histogram(mse_per_neuron, 100) - - return Metric( - histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), - mean_std=MeanStd(mean=mse_per_neuron.mean().item(), std=mse_per_neuron.std().item()), - heatmap=Heatmap(data=heatmap) if return_heatmap else None - ) - -# Tensor Analysis (number of tensors can vary) - -def weight_magnitude(tensors: List[torch.Tensor], model_refs: List[ModelReference]) -> List[Metric]: - output = [] - for tensor, model_reference in zip(tensors, model_refs): - weight_magnitudes = torch.abs(tensor.flatten()) - hist_info = compute_histogram(weight_magnitudes, 100) - output.append(Metric( - histogram=Histogram(count=hist_info[0], - edges=hist_info[1], - widths=hist_info[2] - ), - mean_std=MeanStd(mean=weight_magnitudes.mean().item(), - std=weight_magnitudes.std().item()), - model_ref=model_reference - )) - return output - -def numerical_rank(tensors: List[torch.Tensor], model_refs: List[ModelReference], epsilon: float = 1e-5) -> List[Metric]: - """ - Computes the numerical rank of the representations matrix X based on the singular values - of its sample covariance matrix. The rank is determined as the number of singular values - above a threshold. The threshold is defined as the highest singular value times a given epsilon. - - Parameters: - - X : torch.Tensor - The representations matrix from which the sample covariance matrix will be computed. 
- - epsilon : float, optional - The factor to multiply with the highest singular value to set the threshold (default is 1e-3). - - flip : bool, optional - allows transpose for efficient computation. False only used in testing - Returns: - - int - The numerical rank of the matrix. - - Implemented according to description in the paper: - The Tunnel Effect: Building Data Representations in Deep Neural Networks - https://arxiv.org/pdf/2305.19753.pdf - - """ - output = [] - for tensor, model_reference in zip(tensors, model_refs): - - # Center the data by subtracting the mean - X_centered = tensor - torch.mean(tensor, dim=0) - X_std = torch.std(X_centered, dim=0, unbiased=False) - X_centered /= X_std - - # Compute the sample covariance matrix - covariance_matrix = X_centered.t() @ X_centered / (tensor.shape[0] - 1) - # Compute singular values using SVD on the covariance matrix - U, singular_values, V = torch.svd(covariance_matrix) - # Determine the threshold - threshold = singular_values[0] * epsilon - # Count singular values greater than the threshold - num_rank = torch.sum(singular_values > threshold).item() - - value = int(num_rank) - - output.append( - Metric( - model_ref=model_reference, - mean_std=MeanStd( - mean=value, - std=None), - )) - - return output - # Tasks class MLPTask(Task[torch.Tensor]): diff --git a/mergekit/metric_methods/metrics.py b/mergekit/metric_methods/metrics.py new file mode 100644 index 00000000..c19b7c5c --- /dev/null +++ b/mergekit/metric_methods/metrics.py @@ -0,0 +1,202 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
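# A minimal usage sketch for the comparison helpers collected in this module,
# assuming mergekit is importable and given two same-shaped float32 weight
# matrices; each row is treated as a neuron, matching how the functions defined
# below slice their inputs. Shapes and values here are illustrative only.
import torch
from mergekit.metric_methods.metrics import cossim, smape, scale, mse

a, b = torch.randn(64, 128), torch.randn(64, 128)
comparisons = {
    "cossim": cossim([a, b], return_heatmap=True),  # row-wise cosine similarity plus pairwise heatmap
    "smape": smape([a, b]),                         # symmetric mean absolute percentage error per row
    "scale": scale([a, b]),                         # relative difference of row norms
    "mse": mse([a, b]),                             # mean squared error per row
}
for name, metric in comparisons.items():
    # Each helper returns a Metric carrying a histogram and summary statistics.
    print(name, metric.mean_std.mean, metric.mean_std.std)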
+ +from mergekit.common import ModelReference +from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric +import torch +import torch.nn.functional as F +import numpy as np +from typing import List + +# Helper functions + +def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: + bin_counts, bin_edges = np.histogram(tensor.cpu().numpy(), bins=n_bins) + bin_widths = np.diff(bin_edges) + return bin_counts, bin_edges, bin_widths + +def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + # Normalize the rows of both matrices + A_norm = A / A.norm(dim=1, keepdim=True) + B_norm = B / B.norm(dim=1, keepdim=True) + + # Compute the cosine similarity matrix + similarity_matrix = torch.mm(A_norm, B_norm.t()) + + return similarity_matrix + + +# Tensor Comparisons (Require exactly 2 tensors) + +def smape( + tensors: List[torch.Tensor], **_kwargs +) -> Metric: + """Symmetric Mean Absolute Percentage Error (smape).""" + + numerator = torch.abs(tensors[0] - tensors[1]) + denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])) + smape = torch.mean(torch.div(numerator, denominator), dim=1) + + hist_info = compute_histogram(smape, 100) + + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=smape.mean().item(), std=smape.std().item()) + ) + +def cossim( + tensors: List[torch.Tensor], return_heatmap=False, **_kwargs +) -> Metric: + """Cosine similarity""" + cossim = F.cosine_similarity(tensors[0], tensors[1], dim=1) + + if return_heatmap: + heatmap = cossim_heatmap(tensors[0], tensors[1]) + + assert torch.isclose(cossim, cossim, atol=1e-6).all(), "NaNs in cosine similarity" + assert torch.isclose(cossim, cossim_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-2).all(), "Diagonal elements of cosine similarity matrix do not match" + + hist_info = compute_histogram(cossim, 100) + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=cossim.mean().item(), std=cossim.std().item()), + heatmap=Heatmap(data=heatmap) if return_heatmap else None + ) + +def scale( + tensors: List[torch.Tensor], return_heatmap=False, **_kwargs +) -> Metric: + """ + Scale difference: ratio of absolute difference to average scale. + Complementary to cosine similarity, which measures the angle between two vectors and is invariant to scale. 
+ + values close to 0 indicate that the scales of the two vectors are similar + """ + + norm_0 = torch.norm(tensors[0], dim=1) + norm_1 = torch.norm(tensors[1], dim=1) + + scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) + + if return_heatmap: + norm_0 = norm_0.unsqueeze(1) # shape becomes [num_heads, 1] + norm_1 = norm_1.unsqueeze(0) # shape becomes [1, num_heads] + + # Compute the scale difference between each pair of heads by broadcasting + heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1 + 1e-10) / 2) + + assert torch.isclose(scale_diff, heatmap.diagonal(), atol=1e-4).all(), "Diagonal elements of scale difference matrix do not match" + + hist_info = compute_histogram(scale_diff, 100) + + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=scale_diff.mean().item(), std=scale_diff.std().item()), + heatmap=Heatmap(data=heatmap) if return_heatmap else None + ) + +def mse( + tensors: List[torch.Tensor], return_heatmap: bool =False, **_kwargs +) -> Metric: + """Mean squared error (MSE).""" + if return_heatmap: + # Expand dimensions for broadcasting + tensors_0_exp = tensors[0].unsqueeze(1) # shape becomes [num_heads, 1, ...] + tensors_1_exp = tensors[1].unsqueeze(0) # shape becomes [1, num_heads, ...] + + # Compute squared differences + diffs = (tensors_0_exp - tensors_1_exp) ** 2 + + # Compute mean over all dimensions except the first two + heatmap = diffs.mean(dim=tuple(range(2, diffs.dim()))).cpu().numpy() + + squared_diff = (tensors[0] - tensors[1]) ** 2 + mse_per_neuron = torch.mean(squared_diff, dim=1) + + hist_info = compute_histogram(mse_per_neuron, 100) + + return Metric( + histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), + mean_std=MeanStd(mean=mse_per_neuron.mean().item(), std=mse_per_neuron.std().item()), + heatmap=Heatmap(data=heatmap) if return_heatmap else None + ) + +# Tensor Analysis (number of tensors can vary) + +def weight_magnitude(tensors: List[torch.Tensor], model_refs: List[ModelReference]) -> List[Metric]: + output = [] + for tensor, model_reference in zip(tensors, model_refs): + weight_magnitudes = torch.abs(tensor.flatten()) + hist_info = compute_histogram(weight_magnitudes, 100) + output.append(Metric( + histogram=Histogram(count=hist_info[0], + edges=hist_info[1], + widths=hist_info[2] + ), + mean_std=MeanStd(mean=weight_magnitudes.mean().item(), + std=weight_magnitudes.std().item()), + model_ref=model_reference + )) + return output + +def numerical_rank(tensors: List[torch.Tensor], model_refs: List[ModelReference], epsilon: float = 1e-5) -> List[Metric]: + """ + Computes the numerical rank of the representations matrix X based on the singular values + of its sample covariance matrix. The rank is determined as the number of singular values + above a threshold. The threshold is defined as the highest singular value times a given epsilon. + + Parameters: + - X : torch.Tensor + The representations matrix from which the sample covariance matrix will be computed. + - epsilon : float, optional + The factor to multiply with the highest singular value to set the threshold (default is 1e-3). + - flip : bool, optional - allows transpose for efficient computation. False only used in testing + Returns: + - int + The numerical rank of the matrix. 
+ + Implemented according to description in the paper: + The Tunnel Effect: Building Data Representations in Deep Neural Networks + https://arxiv.org/pdf/2305.19753.pdf + + """ + output = [] + for tensor, model_reference in zip(tensors, model_refs): + + # Center the data by subtracting the mean + X_centered = tensor - torch.mean(tensor, dim=0) + X_std = torch.std(X_centered, dim=0, unbiased=False) + X_centered /= X_std + + # Compute the sample covariance matrix + covariance_matrix = X_centered.t() @ X_centered / (tensor.shape[0] - 1) + # Compute singular values using SVD on the covariance matrix + U, singular_values, V = torch.svd(covariance_matrix.cpu()) + # Determine the threshold + threshold = singular_values[0] * epsilon + # Count singular values greater than the threshold + num_rank = torch.sum(singular_values > threshold).item() + + value = int(num_rank) + + output.append( + Metric( + model_ref=model_reference, + mean_std=MeanStd( + mean=value, + std=None), + )) + + return output From ddbb475b35ebed92581c536f586d4c4c8970ae10 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 4 Jul 2024 15:22:36 +0100 Subject: [PATCH 31/64] add load and save functions to base --- mergekit/metric_methods/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index bb57bc97..315e0b18 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -23,6 +23,8 @@ from collections import defaultdict import torch +import pickle + class MetricMethod(MergeMethod): pass @@ -182,4 +184,13 @@ def print_metric_summary(self): print(f" Has mean/std: {'Yes' if info['has_mean_std'] else 'No'}") print(f" Has histogram: {'Yes' if info['has_histogram'] else 'No'}") print(f" Has heatmap: {'Yes' if info['has_heatmap'] else 'No'}") - print(f" Has model reference: {'Yes' if info['has_model_ref'] else 'No'}") \ No newline at end of file + print(f" Has model reference: {'Yes' if info['has_model_ref'] else 'No'}") + + def save(self, path: str): + path = path + '.pkl' if not path.endswith('.pkl') else path + with open(path, 'wb') as f: + pickle.dump(self, f) + + def load(self, path: str): + with open(path, 'rb') as f: + return pickle.load(f) \ No newline at end of file From 36d28b02f646ca68ac21b257c903cab3eb5b3ee6 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 4 Jul 2024 15:57:21 +0100 Subject: [PATCH 32/64] Internal Representaions analysis - first commit --- representations/representation_metrics.py | 138 ++++++++++ representations/store_representations.py | 133 ++++++++++ .../visualise_representation_results.py | 235 ++++++++++++++++++ 3 files changed, 506 insertions(+) create mode 100644 representations/representation_metrics.py create mode 100644 representations/store_representations.py create mode 100644 representations/visualise_representation_results.py diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py new file mode 100644 index 00000000..f80e9894 --- /dev/null +++ b/representations/representation_metrics.py @@ -0,0 +1,138 @@ +#%% +import torch +import h5py +import random +import numpy as np +import click +import yaml +import h5py + +from mergekit.config import MergeConfiguration +from mergekit.merge import MergeOptions +from mergekit.merge import run_merge +from mergekit.plot_tools.plot_tools import create_app, ResultsHandler + +from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer +from 
mergekit.metric_methods.metrics import cossim, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cossim_heatmap +from mergekit.architecture import WeightInfo + +from tqdm import tqdm + +from pathlib import Path + +import torch.nn.functional as F + +class MetricAggregator(): + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): + pass + + def aggregate(self): + pass + + def clear(self): + self.__init__() + +class CosineSimilarity(MetricAggregator): + def __init__(self, device="cpu"): + self.cosine_similarities = torch.tensor([]).to(device) + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): + batch_similarities = F.cosine_similarity(batch_a, batch_b, dim=1) + self.cosine_similarities = torch.cat((self.cosine_similarities, batch_similarities), dim=0) + + def aggregate(self): + hist = compute_histogram(self.cosine_similarities, 100) + return Metric( + mean_std=MeanStd( + mean=self.cosine_similarities.mean().item(), + std=self.cosine_similarities.std().item()), + histogram=Histogram( + count=hist[0], + edges=hist[1], + widths=hist[2] + ) + ) + +class MSE(MetricAggregator): + def __init__(self, device="cpu"): + self.square_errors = torch.tensor([]).to(device) + + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): + assert batch_a.size(1) == batch_b.size(1) + batch_square_errors = torch.square(batch_a - batch_b).flatten() + self.square_errors = torch.cat((self.square_errors, batch_square_errors), dim=0) + + # CHECK DIMENSIONALITY (X) + + def aggregate(self): + hist = compute_histogram(self.square_errors, 100) + out = Metric( + mean_std=MeanStd( + mean=self.square_errors.mean().item(), + std=self.square_errors.std().item() + ), + histogram=Histogram( + count=hist[0], + edges=hist[1], + widths=hist[2] + ) + ) + self.clear() + return out + +@click.command() +@click.option('--reps_a_path', + default="NEW_Representations_BEE-spoke-data_smol_llama-220M-GQA_train_4000.h5", + help="path to load first set of representations from.") +@click.option('--reps_b_path', + default="NEW_Representations_BEE-spoke-data_smol_llama-220M-openhermes_train_4000.h5", + help="path to load second set of representations from.") +def main(reps_a_path, reps_b_path): + results = Results() + + device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + + reps_a_path = Path(__file__).parent / reps_a_path + reps_b_path = Path(__file__).parent / reps_b_path + + assert reps_a_path.exists(), f"File not found: {reps_a_path}" + assert reps_b_path.exists(), f"File not found: {reps_b_path}" + + with h5py.File(reps_a_path, 'r') as representations_a, \ + h5py.File(reps_b_path, 'r') as representations_b: + + for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Layers', total=len(representations_a)): + metrics = { + 'Cosine Similarity' : CosineSimilarity(device=device), + 'MSE' : MSE(device=device) + } + + layer_results = Layer(WeightInfo(name=layer_a)) + if layer_a != layer_b: + raise ValueError(f'Layer mismatch: {layer_a} != {layer_b}') + + # Load the representations + layer_representations_a = representations_a[layer_a] + layer_representations_b = representations_b[layer_b] + + for batch_a, batch_b in tqdm(zip(layer_representations_a, layer_representations_b), desc='Batches', total=len(layer_representations_a), leave=False): + batch_a = torch.tensor(layer_representations_a[batch_a][:], device=device) + batch_b = torch.tensor(layer_representations_b[batch_b][:], device=device) + # 
Calculate the metrics for each batch + for _, metric in metrics.items(): + metric.process_batch(batch_a, batch_b) + + # Aggregate over the batches and add to the layer results + for name, metric in metrics.items(): + layer_results.add_metric(metric.aggregate(), name) + metric.clear() + + + # Add the layer to the results + results.add_layer(layer_results, layer_a) + + results.save('results_test') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/representations/store_representations.py b/representations/store_representations.py new file mode 100644 index 00000000..6e1f9c6f --- /dev/null +++ b/representations/store_representations.py @@ -0,0 +1,133 @@ +# WORK IN PROGRESS + +import click +import torch +import yaml + +from mergekit.config import MergeConfiguration + +import logging +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +import datasets + +import os + +logging.basicConfig(level=logging.INFO) + +# Set seed +torch.manual_seed(42) +np.random.seed(42) + +import torch +from typing import List + +import h5py +import torch +import random +import numpy as np + +def load_batch_from_hdf5(model_name, batch_idx): + with h5py.File('batches.h5', 'r') as h5file: + dataset_name = f'{model_name}/batch_{batch_idx}' + batch_data = h5file[dataset_name][:] + batch_tensor = torch.tensor(batch_data) + return batch_tensor + +def set_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tensor]: + """Get last non-padded tokens for each layer.""" + last_non_padded_hidden_states = [] + for layer in hidden_states: + batch_size, _, _ = layer.size() + batch_last_tokens = [] + for batch in range(batch_size): + last_non_pad_index = attention_mask[batch].nonzero(as_tuple=True)[0].max() + last_token = layer[batch, last_non_pad_index, :] + batch_last_tokens.append(last_token.unsqueeze(0)) + last_non_padded_hidden_states.append(torch.cat(batch_last_tokens, dim=0)) + return last_non_padded_hidden_states + + +@click.command() +@click.option('--model_path', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') +@click.option('--output_path', default="./representations/", help='folder to store the result in.') +@click.option('--dataset', default="arcee-ai/sec-data-mini", help='dataset to use.') +@click.option('--batch_size', default=8, help='batch size.') +@click.option('--max_length', default=1024, help='maximum length of the input.') +@click.option('--dataset_size', default=4000, help='size of the dataset.') +@click.option('--dataset_column', default="text", help='column of the dataset to use.') +@click.option('--dataset_subset', default="train", help='subset of the dataset to use.') +def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, dataset_column, dataset_subset): + + device = "cuda" if torch.cuda.is_available() \ + else "mps" if torch.backends.mps.is_available() \ + else "cpu" + + # if resource is a problem + quantization_config = BitsAndBytesConfig(load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16) + + dataset = datasets.load_dataset(dataset, split=dataset_subset) + if dataset_size: + dataset = 
dataset.select(range(dataset_size)) + + + model = AutoModelForCausalLM.from_pretrained(model_path, + device_map="auto", + quantization_config=quantization_config if device == "cuda" else None, + output_hidden_states=True) + + + tokenizer = AutoTokenizer.from_pretrained(model_path) + + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + model.eval() + + set_seed(42) + + dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True) + + output_name = f'NEW_Representations_{model.name_or_path.replace("/","_")}_{dataset_subset}_{dataset_size}' + assert not os.path.exists(output_path+f'{output_name}.h5'), f'{output_name}.h5 already exists.' + + with h5py.File(f'{output_name}.h5', 'w') as h5file: + for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")): + inputs = tokenizer(batch, return_tensors="pt", padding="longest", max_length=max_length, truncation=True).to(device) + with torch.no_grad(): + outputs = model(**inputs) + attention_mask = inputs["attention_mask"] + hidden_states = outputs.hidden_states + last_non_padded_hidden_states = get_last_non_padded_tokens(hidden_states, attention_mask) + + # Remove the first element to account for the input layer not being considered a model hidden layer + # This adjustment is necessary for analyses focusing on the model's internal transformations + last_non_padded_hidden_states = last_non_padded_hidden_states[1:] + for layer, hidden_state in enumerate(last_non_padded_hidden_states): + layer_group = h5file.require_group(f'layer_{layer}') + file_name = f'batch_{batch_idx}.pt' + + layer_group.create_dataset(file_name, data=hidden_state.to('cpu'), compression="gzip") + + # Ensure that the length of last_non_padded_hidden_states matches the number of model hidden layers minus one + assert len(last_non_padded_hidden_states) == model.config.num_hidden_layers, "Length of last_non_padded_hidden_states \ + does not match expected number of hidden layers." + + +if __name__ == "__main__": + main() diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py new file mode 100644 index 00000000..0bf1e460 --- /dev/null +++ b/representations/visualise_representation_results.py @@ -0,0 +1,235 @@ +import torch +import h5py +import random +import numpy as np +import click +import yaml +import h5py + +from mergekit.config import MergeConfiguration +from mergekit.merge import MergeOptions +from mergekit.merge import run_merge +from mergekit.plot_tools.plot_tools import create_app, ResultsHandler + +from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer +from mergekit.metric_methods.metrics import cossim, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cossim_heatmap +from mergekit.architecture import WeightInfo + + +from typing import List, Tuple, Dict +import plotly.graph_objs as go +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import numpy as np +from mergekit.graph import Task + + + +class CustomResultsHandler(ResultsHandler): + """ + Object to handle metrics results. Allows for easy plotting of metrics by layer and across layers. + + Input: + Use the load_metrics method to load the metrics into the handler. + metrics: List of tasks and their metrics. This is the output of the run_measure function in mergekit.measure. + + Attributes: + all_stats: Dictionary of recorded statistics for each layer. e.g. 
{'layer_name': {'cossim_mean': 0.5, 'cossim_std': 0.1}} + metric_names: List of names of all statistics available. e.g. ['cossim_mean', 'cossim_std'] + layer_names: List of layer names. + + Methods: + load_metrics: Load the metrics into the handler. + # stats_at_layer: Get the metrics for a specific layer. + # info_at_layer: Get the weight info for a specific layer. + line_plot: Plot a line plot of the chosen stat across layers. + plotly_layer_histogram: Plot a histogram of the stat for a specific layer. + """ + def __init__(self):#, metrics: List[Tuple[Task, Layer]]): + self.results = Results() + # self.load_metrics(metrics) + + def load_metrics(self, metrics: List[Tuple[Task, Layer]]): + self.metric_names = [] + for task, metric in metrics: + if metric is not None: + self.results.add_layer(metric, name=task.weight_info.name) + self.metric_names.extend(list(metric.metrics.keys())) + self.layer_names = list(self.results.layers.keys()) + self.metric_names = list(set(self.metric_names)) + + def load_results(self, results: Results): + self.results = results + self.layer_names = list(self.results.layers.keys()) + self.metric_names = list(set([metric for layer in self.results.layers.values() for metric in layer.metrics.keys()])) + + def categorise_layers(self, layer_names): + # Hardcoded layernames for now - can be extended to include more categories or further generalised based on config + categories = [] + for name in layer_names: + if 'Attention Block' in name: + categories.append('Attention Block') + elif 'mlp' in name: + categories.append('MLP') + elif 'layernorm' in name: + categories.append('LayerNorm') + else: + categories.append('Other') + return categories + + def plotly_line_plots(self, metric_name:str): + if metric_name not in self.metric_names: + print(f"Stat {metric_name} not found") + return [] + + layer_names = self.layer_names + means, stds, model_refs = self.results.get_lineplot_data(metric_name) + traces = [] + available_shapes = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] + + if len(model_refs) > 1: + unique_categories = [str(ref) for ref in model_refs] + layer_categories = [[str(model_refs[i])]*len(layer_names) for i in range(len(model_refs))] + else: + layer_categories = [self.categorise_layers(layer_names)] + unique_categories = list(set(layer_categories[0])) + for i, model_ref in enumerate(model_refs): + traces.extend(self._plotly_line_plot(layer_names, means[i], stds[i], layer_categories[i], unique_categories, shape=available_shapes[i%len(available_shapes)])) + + return traces, layer_names + + def _plotly_line_plot(self, x_values, means, stds, layer_categories, unique_categories, shape:str='circle', **kwargs): + """ + Plot the stat values across layers using Plotly. + + Args: + stat (str): The name of the stat to plot. + + Returns: + List[go.Scatter]: List of Plotly Scatter objects. 
+ """ + + # Assign a unique color to each category + cmap = plt.get_cmap('Set1', len(unique_categories)) + colors = [mcolors.to_hex(cmap(i)) for i in range(len(unique_categories))] + + category_styles = {cat: colors[i % len(colors)] for i, cat in enumerate(unique_categories)} + + traces = [] + + for category in unique_categories: + y_category = [means[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + std_category = [stds[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + if all([y is None for y in y_category]): + continue + + traces.append(go.Scatter( + x=x_values, + y=y_category, + error_y=dict( + type='data', + array=std_category, + visible=True + ), + mode='markers', + name=category, + marker=dict(color=category_styles[category]), + marker_symbol=shape + )) + return traces + + def plotly_layer_heatmap(self, layer_name:str, metric_name:str): + """ + Plot the stat values as a heatmap. + + Args: + layer_name (str): The name of the layer. + metric_name (str): The name of the stat to plot. + Returns: + go.Heatmap: Plotly Heatmap object. + """ + metrics_list = self.results.layers[layer_name].metrics[metric_name] + if len(metrics_list) > 1: + raise Warning(f"Multiple heatmaps found for {metric_name} at layer {layer_name}. Using the first one.") + + heatmap = self.results.layers[layer_name].metrics[metric_name][0].heatmap.data + + return [go.Heatmap( + z=heatmap, + colorscale='RdBu' + )] + + def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): + """ + Set the attributes of the plot. + + Args: + ax: The matplotlib Axes object. + stat (str): The name of the stat. + **kwargs: Additional keyword arguments for plot attributes. + """ + # Defaults + ax.set_ylabel(kwargs.get('ylabel', stat)) + ax.set_xticks(np.arange(len(self.layer_names))) + ax.set_xticklabels(self.layer_names, rotation=45) + ax.set_title(kwargs.get('title', f'{stat.replace("_", " ").title()}')) + + # Set additional attributes + for kwarg in ax_kwargs: + if kwarg in kwargs: + getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) + + def plotly_layer_histogram(self, layer_name: str, metric_name: str): + metric_list = self.results.layers[layer_name].metrics[metric_name] + colors = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] # (X) + + traces = [] + for i, metric in enumerate(metric_list): + hist = metric.histogram + count, edges, widths = hist.count, hist.edges, hist.widths + traces.append(go.Bar( + x=edges[:-1], + y=count, + width=widths, + marker=dict( + color=colors[i], + opacity=0.75, + line=dict( + color='black', + width=1 + ) + ), + name=str(metric.model_ref) + )) + return traces + + def layer_plot_options(self, layer_name: str): + layer = self.results.layers[layer_name] + + return [ + {"label": f"{metric.title()} Histogram", "value": [metric, 'histogram']} + for metric in layer.metrics_with_attribute('histogram') + ] + [ + {"label": f"{metric.title()} Heatmap", "value": [metric, 'heatmap']} + for metric in layer.metrics_with_attribute('heatmap') + ] + + +@click.command() +@click.option('--results_path', + default="./representation_results_test.pkl", + help="path to load the results from.") +def main(results_path): + results = Results() + print('warning: results_path is hardcoded in main()') + results_path = '/Users/elliotstein/Documents/Arcee/mergekit/representations/results_test.pkl' + results = results.load(results_path) + + handler = CustomResultsHandler() + handler.load_results(results) + + app = 
create_app(results_handler=handler) + app.run_server() + +if __name__ == '__main__': + main() \ No newline at end of file From fa6d0988df1a363581bc7897f748cddfdf8bf67a Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Mon, 8 Jul 2024 17:10:35 +0100 Subject: [PATCH 33/64] refactor variable name for consistency --- mergekit/metric_methods/all_metrics.py | 10 +++++----- mergekit/metric_methods/base.py | 5 ++++- mergekit/metric_methods/metrics.py | 16 ++++++++-------- mergekit/plot_tools/plot_tools.py | 8 +++++--- representations/representation_metrics.py | 4 ++-- representations/store_representations.py | 2 +- .../visualise_representation_results.py | 16 ++++++---------- 7 files changed, 31 insertions(+), 30 deletions(-) diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index 453cb39a..b78a25f7 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -22,7 +22,7 @@ from mergekit.metric_methods.base import MetricMethod, Layer import torch from typing import Dict, List, Any -from mergekit.metric_methods.metrics import cossim, smape, scale, mse, weight_magnitude, numerical_rank +from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2): """Validate tensor shapes and count.""" @@ -105,7 +105,7 @@ def execute( layer_results = Layer(metrics={}, weight_info=self.weight_info) - layer_results.add_metric(cossim(weights, return_heatmap=False), name = 'cossim') + layer_results.add_metric(cosine_similarity(weights, return_heatmap=False), name = 'cosine_similarity') layer_results.add_metric(smape(weights), name = 'smape') layer_results.add_metric(scale(weights, return_heatmap=False), name = 'scale') layer_results.add_metric(mse(weights, return_heatmap=False), name = 'mse') @@ -163,10 +163,10 @@ def execute( weight_info=self.weight_info) - layer_results.add_metric(cossim([model_0_heads.view(model_0_heads.shape[0], -1), + layer_results.add_metric(cosine_similarity([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)], return_heatmap=True), - name = 'cossim') + name = 'cosine_similarity') layer_results.add_metric(smape([model_0_heads.view(model_0_heads.shape[0], -1), model_1_heads.view(model_1_heads.shape[0], -1)]), name = 'smape') @@ -220,7 +220,7 @@ def execute( layer_results = Layer(metrics={}, weight_info=self.weight_info) - layer_results.add_metric(cossim([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'cossim') + layer_results.add_metric(cosine_similarity([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'cosine_similarity') layer_results.add_metric(smape([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)]), name = 'smape') layer_results.add_metric(scale([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'scale') layer_results.add_metric(mse([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'mse') diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 315e0b18..a166a31d 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -119,6 +119,7 @@ class Results: # Class to store the statistics for each layer def __init__(self): self.layers: Dict[str, Layer] = {} + self.others: Dict[str, Metric] = {} def add_layer(self, layer: Layer, 
name: str): if name not in self.layers.keys(): @@ -193,4 +194,6 @@ def save(self, path: str): def load(self, path: str): with open(path, 'rb') as f: - return pickle.load(f) \ No newline at end of file + results = pickle.load(f) + assert isinstance(results, Results), "Loaded object is not a Results object" + return results \ No newline at end of file diff --git a/mergekit/metric_methods/metrics.py b/mergekit/metric_methods/metrics.py index c19b7c5c..ad3b2060 100644 --- a/mergekit/metric_methods/metrics.py +++ b/mergekit/metric_methods/metrics.py @@ -27,7 +27,7 @@ def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]: bin_widths = np.diff(bin_edges) return bin_counts, bin_edges, bin_widths -def cossim_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: +def cosine_similarity_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # Normalize the rows of both matrices A_norm = A / A.norm(dim=1, keepdim=True) B_norm = B / B.norm(dim=1, keepdim=True) @@ -56,22 +56,22 @@ def smape( mean_std=MeanStd(mean=smape.mean().item(), std=smape.std().item()) ) -def cossim( +def cosine_similarity( tensors: List[torch.Tensor], return_heatmap=False, **_kwargs ) -> Metric: """Cosine similarity""" - cossim = F.cosine_similarity(tensors[0], tensors[1], dim=1) + cosine_similarity = F.cosine_similarity(tensors[0], tensors[1], dim=1) if return_heatmap: - heatmap = cossim_heatmap(tensors[0], tensors[1]) + heatmap = cosine_similarity_heatmap(tensors[0], tensors[1]) - assert torch.isclose(cossim, cossim, atol=1e-6).all(), "NaNs in cosine similarity" - assert torch.isclose(cossim, cossim_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-2).all(), "Diagonal elements of cosine similarity matrix do not match" + assert torch.isclose(cosine_similarity, cosine_similarity, atol=1e-6).all(), "NaNs in cosine similarity" + assert torch.isclose(cosine_similarity, cosine_similarity_heatmap(tensors[0], tensors[1]).diagonal(), atol=1e-2).all(), "Diagonal elements of cosine similarity matrix do not match" - hist_info = compute_histogram(cossim, 100) + hist_info = compute_histogram(cosine_similarity, 100) return Metric( histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]), - mean_std=MeanStd(mean=cossim.mean().item(), std=cossim.std().item()), + mean_std=MeanStd(mean=cosine_similarity.mean().item(), std=cosine_similarity.std().item()), heatmap=Heatmap(data=heatmap) if return_heatmap else None ) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 047b5fdc..da965b50 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -11,6 +11,8 @@ from mergekit.metric_methods.all_metrics import Layer from mergekit.metric_methods.base import Results +global_colours_list = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] + class ResultsHandler: """ Object to handle metrics results. Allows for easy plotting of metrics by layer and across layers. @@ -20,8 +22,8 @@ class ResultsHandler: metrics: List of tasks and their metrics. This is the output of the run_measure function in mergekit.measure. Attributes: - all_stats: Dictionary of recorded statistics for each layer. e.g. {'layer_name': {'cossim_mean': 0.5, 'cossim_std': 0.1}} - metric_names: List of names of all statistics available. e.g. ['cossim_mean', 'cossim_std'] + all_stats: Dictionary of recorded statistics for each layer. e.g. 
{'layer_name': {'cosine_similarity_mean': 0.5, 'cosine_similarity_std': 0.1}} + metric_names: List of names of all statistics available. e.g. ['cosine_similarity_mean', 'cosine_similarity_std'] layer_names: List of layer names. Methods: @@ -218,7 +220,7 @@ def create_line_plot_section(results_handler): id='line-plot-dropdown', options=[{'label': metric_name.replace('_', ' ').title(), 'value': metric_name} for metric_name in results_handler.metric_names], - value='cossim', + value='cosine_similarity', style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} ), dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index f80e9894..58ff4975 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -13,7 +13,7 @@ from mergekit.plot_tools.plot_tools import create_app, ResultsHandler from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer -from mergekit.metric_methods.metrics import cossim, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cossim_heatmap +from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap from mergekit.architecture import WeightInfo from tqdm import tqdm @@ -32,7 +32,7 @@ def aggregate(self): def clear(self): self.__init__() -class CosineSimilarity(MetricAggregator): +class Cosine_Similarity(MetricAggregator): def __init__(self, device="cpu"): self.cosine_similarities = torch.tensor([]).to(device) diff --git a/representations/store_representations.py b/representations/store_representations.py index 6e1f9c6f..f0d3c16f 100644 --- a/representations/store_representations.py +++ b/representations/store_representations.py @@ -103,7 +103,7 @@ def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True) - output_name = f'NEW_Representations_{model.name_or_path.replace("/","_")}_{dataset_subset}_{dataset_size}' + output_name = f'Representations_{model.name_or_path.replace("/","_")}_{dataset_subset}_{dataset_size}' assert not os.path.exists(output_path+f'{output_name}.h5'), f'{output_name}.h5 already exists.' 
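# A short read-back sketch for the representations file this script writes
# (the filename below is hypothetical): each top-level HDF5 group holds one
# layer ('layer_0', 'layer_1', ...) and each dataset inside it holds one batch
# of last-token hidden states ('batch_0.pt', ...), the layout that
# representation_metrics.py later iterates over.
import h5py
import torch

with h5py.File("Representations_example_train_4000.h5", "r") as f:
    for layer_name in f:                  # 'layer_0', 'layer_1', ...
        for batch_name in f[layer_name]:  # 'batch_0.pt', 'batch_1.pt', ...
            batch = torch.tensor(f[layer_name][batch_name][:])  # (batch_size, hidden_dim)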
with h5py.File(f'{output_name}.h5', 'w') as h5file: diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py index 0bf1e460..2bc08b08 100644 --- a/representations/visualise_representation_results.py +++ b/representations/visualise_representation_results.py @@ -9,10 +9,10 @@ from mergekit.config import MergeConfiguration from mergekit.merge import MergeOptions from mergekit.merge import run_merge -from mergekit.plot_tools.plot_tools import create_app, ResultsHandler +from mergekit.plot_tools.plot_tools import create_app, ResultsHandler, global_colours_list from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer -from mergekit.metric_methods.metrics import cossim, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cossim_heatmap +from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap from mergekit.architecture import WeightInfo @@ -23,8 +23,6 @@ import numpy as np from mergekit.graph import Task - - class CustomResultsHandler(ResultsHandler): """ Object to handle metrics results. Allows for easy plotting of metrics by layer and across layers. @@ -34,8 +32,8 @@ class CustomResultsHandler(ResultsHandler): metrics: List of tasks and their metrics. This is the output of the run_measure function in mergekit.measure. Attributes: - all_stats: Dictionary of recorded statistics for each layer. e.g. {'layer_name': {'cossim_mean': 0.5, 'cossim_std': 0.1}} - metric_names: List of names of all statistics available. e.g. ['cossim_mean', 'cossim_std'] + all_stats: Dictionary of recorded statistics for each layer. e.g. {'layer_name': {'cosine_similarity_mean': 0.5, 'cosine_similarity_std': 0.1}} + metric_names: List of names of all statistics available. e.g. ['cosine_similarity_mean', 'cosine_similarity_std'] layer_names: List of layer names. 
Methods: @@ -80,7 +78,7 @@ def categorise_layers(self, layer_names): def plotly_line_plots(self, metric_name:str): if metric_name not in self.metric_names: print(f"Stat {metric_name} not found") - return [] + return [], [] layer_names = self.layer_names means, stds, model_refs = self.results.get_lineplot_data(metric_name) @@ -181,7 +179,6 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): def plotly_layer_histogram(self, layer_name: str, metric_name: str): metric_list = self.results.layers[layer_name].metrics[metric_name] - colors = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] # (X) traces = [] for i, metric in enumerate(metric_list): @@ -192,7 +189,7 @@ def plotly_layer_histogram(self, layer_name: str, metric_name: str): y=count, width=widths, marker=dict( - color=colors[i], + color=global_colours_list[i], opacity=0.75, line=dict( color='black', @@ -214,7 +211,6 @@ def layer_plot_options(self, layer_name: str): for metric in layer.metrics_with_attribute('heatmap') ] - @click.command() @click.option('--results_path', default="./representation_results_test.pkl", From 4949458d108d741487a6ab68dd4bccab8e263940 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Mon, 8 Jul 2024 17:11:39 +0100 Subject: [PATCH 34/64] Plot heatmaps stored in results.others --- mergekit/plot_tools/plot_tools.py | 46 ++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index da965b50..796dc055 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -203,7 +203,8 @@ def create_app(results_handler): app.layout = html.Div([ create_header(), create_line_plot_section(results_handler), - create_layer_metrics_section() + create_layer_metrics_section(), + create_heatmap_section(results_handler) ]) register_callbacks(app, results_handler) @@ -238,6 +239,24 @@ def create_layer_metrics_section(): dcc.Graph(id='layer-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}) ], className='container-fluid') +def create_heatmap_section(results_handler): + if hasattr(results_handler.results, 'others') and isinstance(results_handler.results.others, dict): + heatmap_sections = [] + for i, (key, value) in enumerate(results_handler.results.others.items()): + model_name = key.split('Representations')[1] + model_name = model_name.split('.')[0].replace('_', ' ').replace('-',' ').title() + metric = key.split('||')[-1].replace('_', ' ').replace('-',' ').title() + + title = f'{model_name} - {metric}' + + heatmap_sections.append(html.Div([ + html.H3(f'Heatmap: {title}', style={'textAlign': 'center'}), + dcc.Graph(id=f'heatmap-plot-{i}', style={'width': '30%', 'height': '30%'}) + ], className='container-fluid')) + return html.Div(heatmap_sections) + else: + return html.Div() + def default_option(options, current_value): if not options: return None @@ -337,6 +356,31 @@ def update_line_plot(selected_metric): yaxis=dict(title=selected_metric.replace('_', ' ').title()) ) return fig + # Dynamically create callbacks for each heatmap plot + if hasattr(results_handler.results, 'others') and isinstance(results_handler.results.others, dict): + for i, (key, value) in enumerate(results_handler.results.others.items()): + if isinstance(value.data, (list, np.ndarray)): # Assuming heatmap data is in array-like format + @app.callback( + Output(f'heatmap-plot-{i}', 'figure'), + Input(f'heatmap-plot-{i}', 'id') # Dummy input to trigger the callback on load + ) + 
def update_heatmap_plot(_key=key): + key = list(results_handler.results.others.keys())[int(_key.split('-')[-1])] + heatmap_data = results_handler.results.others[key].data + fig = go.Figure(data=go.Heatmap( + z=heatmap_data, + colorscale='Viridis', # Using Viridis colormap + zmin=np.nanmin(heatmap_data), # Set the scale min to the min data value + zmax=np.nanmax(heatmap_data), # Set the scale max to the max data value + colorbar=dict(title='Scale') # Customize the color bar + )) + + fig.update_layout( + title=f"Heatmap: {_key}", + xaxis_title="X Axis", + yaxis_title="Y Axis" + ) + return fig def create_figure(traces, title, xaxis_title, yaxis_title): fig = go.Figure() From bd62ed9f3a52ecc42b13c8c65a0a4510a4136311 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Mon, 8 Jul 2024 17:16:09 +0100 Subject: [PATCH 35/64] generalised results handler to load from metrics list or ready-made Results object --- mergekit/plot_tools/plot_tools.py | 13 +- mergekit/scripts/run_metrics.py | 3 +- .../visualise_representation_results.py | 190 +----------------- 3 files changed, 11 insertions(+), 195 deletions(-) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 796dc055..3f9ce101 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -33,9 +33,8 @@ class ResultsHandler: line_plot: Plot a line plot of the chosen stat across layers. plotly_layer_histogram: Plot a histogram of the stat for a specific layer. """ - def __init__(self, metrics: List[Tuple[Task, Layer]]): + def __init__(self): self.results = Results() - self.load_metrics(metrics) def load_metrics(self, metrics: List[Tuple[Task, Layer]]): self.metric_names = [] @@ -45,6 +44,11 @@ def load_metrics(self, metrics: List[Tuple[Task, Layer]]): self.metric_names.extend(list(metric.metrics.keys())) self.layer_names = list(self.results.layers.keys()) self.metric_names = list(set(self.metric_names)) + + def load_results(self, results: Results): + self.results = results + self.layer_names = list(self.results.layers.keys()) + self.metric_names = list(set([metric for layer in self.results.layers.values() for metric in layer.metrics.keys()])) def categorise_layers(self, layer_names): # Hardcoded layernames for now - can be extended to include more categories or further generalised based on config @@ -63,7 +67,7 @@ def categorise_layers(self, layer_names): def plotly_line_plots(self, metric_name:str): if metric_name not in self.metric_names: print(f"Stat {metric_name} not found") - return [] + return [], [] layer_names = self.layer_names means, stds, model_refs = self.results.get_lineplot_data(metric_name) @@ -164,7 +168,6 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): def plotly_layer_histogram(self, layer_name: str, metric_name: str): metric_list = self.results.layers[layer_name].metrics[metric_name] - colors = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] # (X) traces = [] for i, metric in enumerate(metric_list): @@ -175,7 +178,7 @@ def plotly_layer_histogram(self, layer_name: str, metric_name: str): y=count, width=widths, marker=dict( - color=colors[i], + color=global_colours_list[i], opacity=0.75, line=dict( color='black', diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py index d5452808..624af005 100644 --- a/mergekit/scripts/run_metrics.py +++ b/mergekit/scripts/run_metrics.py @@ -28,7 +28,8 @@ def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory) ), ) - handler = 
ResultsHandler(metrics_results) + handler = ResultsHandler() + handler.load_metrics(metrics_results) app = create_app(results_handler=handler) app.run_server() diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py index 2bc08b08..578a7e9d 100644 --- a/representations/visualise_representation_results.py +++ b/representations/visualise_representation_results.py @@ -23,194 +23,6 @@ import numpy as np from mergekit.graph import Task -class CustomResultsHandler(ResultsHandler): - """ - Object to handle metrics results. Allows for easy plotting of metrics by layer and across layers. - - Input: - Use the load_metrics method to load the metrics into the handler. - metrics: List of tasks and their metrics. This is the output of the run_measure function in mergekit.measure. - - Attributes: - all_stats: Dictionary of recorded statistics for each layer. e.g. {'layer_name': {'cosine_similarity_mean': 0.5, 'cosine_similarity_std': 0.1}} - metric_names: List of names of all statistics available. e.g. ['cosine_similarity_mean', 'cosine_similarity_std'] - layer_names: List of layer names. - - Methods: - load_metrics: Load the metrics into the handler. - # stats_at_layer: Get the metrics for a specific layer. - # info_at_layer: Get the weight info for a specific layer. - line_plot: Plot a line plot of the chosen stat across layers. - plotly_layer_histogram: Plot a histogram of the stat for a specific layer. - """ - def __init__(self):#, metrics: List[Tuple[Task, Layer]]): - self.results = Results() - # self.load_metrics(metrics) - - def load_metrics(self, metrics: List[Tuple[Task, Layer]]): - self.metric_names = [] - for task, metric in metrics: - if metric is not None: - self.results.add_layer(metric, name=task.weight_info.name) - self.metric_names.extend(list(metric.metrics.keys())) - self.layer_names = list(self.results.layers.keys()) - self.metric_names = list(set(self.metric_names)) - - def load_results(self, results: Results): - self.results = results - self.layer_names = list(self.results.layers.keys()) - self.metric_names = list(set([metric for layer in self.results.layers.values() for metric in layer.metrics.keys()])) - - def categorise_layers(self, layer_names): - # Hardcoded layernames for now - can be extended to include more categories or further generalised based on config - categories = [] - for name in layer_names: - if 'Attention Block' in name: - categories.append('Attention Block') - elif 'mlp' in name: - categories.append('MLP') - elif 'layernorm' in name: - categories.append('LayerNorm') - else: - categories.append('Other') - return categories - - def plotly_line_plots(self, metric_name:str): - if metric_name not in self.metric_names: - print(f"Stat {metric_name} not found") - return [], [] - - layer_names = self.layer_names - means, stds, model_refs = self.results.get_lineplot_data(metric_name) - traces = [] - available_shapes = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] - - if len(model_refs) > 1: - unique_categories = [str(ref) for ref in model_refs] - layer_categories = [[str(model_refs[i])]*len(layer_names) for i in range(len(model_refs))] - else: - layer_categories = [self.categorise_layers(layer_names)] - unique_categories = list(set(layer_categories[0])) - for i, model_ref in enumerate(model_refs): - traces.extend(self._plotly_line_plot(layer_names, means[i], stds[i], layer_categories[i], unique_categories, 
shape=available_shapes[i%len(available_shapes)])) - - return traces, layer_names - - def _plotly_line_plot(self, x_values, means, stds, layer_categories, unique_categories, shape:str='circle', **kwargs): - """ - Plot the stat values across layers using Plotly. - - Args: - stat (str): The name of the stat to plot. - - Returns: - List[go.Scatter]: List of Plotly Scatter objects. - """ - - # Assign a unique color to each category - cmap = plt.get_cmap('Set1', len(unique_categories)) - colors = [mcolors.to_hex(cmap(i)) for i in range(len(unique_categories))] - - category_styles = {cat: colors[i % len(colors)] for i, cat in enumerate(unique_categories)} - - traces = [] - - for category in unique_categories: - y_category = [means[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] - std_category = [stds[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] - if all([y is None for y in y_category]): - continue - - traces.append(go.Scatter( - x=x_values, - y=y_category, - error_y=dict( - type='data', - array=std_category, - visible=True - ), - mode='markers', - name=category, - marker=dict(color=category_styles[category]), - marker_symbol=shape - )) - return traces - - def plotly_layer_heatmap(self, layer_name:str, metric_name:str): - """ - Plot the stat values as a heatmap. - - Args: - layer_name (str): The name of the layer. - metric_name (str): The name of the stat to plot. - Returns: - go.Heatmap: Plotly Heatmap object. - """ - metrics_list = self.results.layers[layer_name].metrics[metric_name] - if len(metrics_list) > 1: - raise Warning(f"Multiple heatmaps found for {metric_name} at layer {layer_name}. Using the first one.") - - heatmap = self.results.layers[layer_name].metrics[metric_name][0].heatmap.data - - return [go.Heatmap( - z=heatmap, - colorscale='RdBu' - )] - - def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): - """ - Set the attributes of the plot. - - Args: - ax: The matplotlib Axes object. - stat (str): The name of the stat. - **kwargs: Additional keyword arguments for plot attributes. 
- """ - # Defaults - ax.set_ylabel(kwargs.get('ylabel', stat)) - ax.set_xticks(np.arange(len(self.layer_names))) - ax.set_xticklabels(self.layer_names, rotation=45) - ax.set_title(kwargs.get('title', f'{stat.replace("_", " ").title()}')) - - # Set additional attributes - for kwarg in ax_kwargs: - if kwarg in kwargs: - getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) - - def plotly_layer_histogram(self, layer_name: str, metric_name: str): - metric_list = self.results.layers[layer_name].metrics[metric_name] - - traces = [] - for i, metric in enumerate(metric_list): - hist = metric.histogram - count, edges, widths = hist.count, hist.edges, hist.widths - traces.append(go.Bar( - x=edges[:-1], - y=count, - width=widths, - marker=dict( - color=global_colours_list[i], - opacity=0.75, - line=dict( - color='black', - width=1 - ) - ), - name=str(metric.model_ref) - )) - return traces - - def layer_plot_options(self, layer_name: str): - layer = self.results.layers[layer_name] - - return [ - {"label": f"{metric.title()} Histogram", "value": [metric, 'histogram']} - for metric in layer.metrics_with_attribute('histogram') - ] + [ - {"label": f"{metric.title()} Heatmap", "value": [metric, 'heatmap']} - for metric in layer.metrics_with_attribute('heatmap') - ] - @click.command() @click.option('--results_path', default="./representation_results_test.pkl", @@ -221,7 +33,7 @@ def main(results_path): results_path = '/Users/elliotstein/Documents/Arcee/mergekit/representations/results_test.pkl' results = results.load(results_path) - handler = CustomResultsHandler() + handler = ResultsHandler() handler.load_results(results) app = create_app(results_handler=handler) From 9c875142959b0487d8111ff7c3d3275a6dcd88be Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 10:22:10 +0100 Subject: [PATCH 36/64] abstracted and add skip block analysis --- representations/representation_metrics.py | 213 +++++++++++++++++----- 1 file changed, 169 insertions(+), 44 deletions(-) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 58ff4975..927f6a97 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -1,4 +1,3 @@ -#%% import torch import h5py import random @@ -15,6 +14,7 @@ from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap from mergekit.architecture import WeightInfo +from typing import List from tqdm import tqdm @@ -63,8 +63,6 @@ def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): batch_square_errors = torch.square(batch_a - batch_b).flatten() self.square_errors = torch.cat((self.square_errors, batch_square_errors), dim=0) - # CHECK DIMENSIONALITY (X) - def aggregate(self): hist = compute_histogram(self.square_errors, 100) out = Metric( @@ -81,58 +79,185 @@ def aggregate(self): self.clear() return out +class LayerByIndex: + def __init__(self, reps_path): + self.reps_path = reps_path + self.representations = None + self.layers = None + self.iter_index = 0 + + def __enter__(self): + self.representations = h5py.File(self.reps_path, 'r') + self.layers = list(self.representations.keys()) + return self + + def __exit__(self, *args, **kwargs): + if self.representations: + self.representations.close() + + def __getitem__(self, idx): + return self.representations[self.layers[idx]] + + def __len__(self): + 
return len(self.layers) + + def __iter__(self): + self.iter_index = 0 + return self + + def __next__(self): + if self.iter_index < len(self.layers): + layer = self.representations[self.layers[self.iter_index]] + self.iter_index += 1 + return layer + else: + raise StopIteration + + def batches_in_layer(self, idx): + return len(self.representations[self.layers[idx]]) + +def compare_representations(representations_a: h5py.File, representations_b: h5py.File, metrics_classes, device: str): + results=Results() + + # Compare corresponding layers from both models + for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Comparing Represenations at layer', total=len(representations_a), leave=False): + metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] + + layer_results = Layer(WeightInfo(name=layer_a)) + if layer_a != layer_b: + raise ValueError(f'Layer mismatch: {layer_a} != {layer_b}') + + # Load the representations + layer_representations_a = representations_a[layer_a] + layer_representations_b = representations_b[layer_b] + + for batch_a, batch_b in tqdm(zip(layer_representations_a, layer_representations_b), desc='Batch', total=len(layer_representations_a), leave=False): + batch_a = torch.tensor(layer_representations_a[batch_a][:], device=device) + batch_b = torch.tensor(layer_representations_b[batch_b][:], device=device) + # Calculate the metrics for each batch + for metric in metrics: + metric.process_batch(batch_a, batch_b) + + # Aggregate over the batches and add to the layer results + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower().lower()) + metric.clear() + + + # Add the layer to the results + results.add_layer(layer_results, layer_a) + + return results + +def compute_skip_block_metrics(reps_path:str, skip_layers:int, metric_classes:List[MetricAggregator], device:str='cpu'): + results = Results() + with LayerByIndex(reps_path) as reps: + for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {skip_layers}-block, Block Start at Layer', total=len(reps)-skip_layers, leave=False): + # Create metrics + metrics = [metric_class(device=device) for metric_class in metric_classes] + + if idx + skip_layers >= len(reps): + continue + block_end = reps[idx + skip_layers] + + # Each metric processes every batch + for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', total=len(block_start), leave=False): + batch_0 = torch.tensor(block_start[batch_0][:]).to(device) + batch_1 = torch.tensor(block_end[batch_1][:]).to(device) + + for metric in metrics: + metric.process_batch(batch_0, batch_1) + + # Aggregate metrics and add to results + layer_results = Layer(WeightInfo(name=f"{idx}")) + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower().lower()) + results.add_layer(layer_results, f"Layer {idx}") + + # Clear memory from metrics + for metric in metrics: + metric.clear() + return results + +def resultslist_to_heatmap(all_results, metric_names:List[str]) -> dict: + rows = len(all_results) + cols = max([len(result.layers) for result in all_results]) + heatmaps = {} + for metric_name in metric_names: + heatmap = np.full((rows, cols), np.nan) + + + for i, result in enumerate(all_results): + # row = len(all_results) - (i+1) + for j, layer in enumerate(result.layers): + heatmap[i, j] = result.layers[layer].metrics[metric_name][0].mean_std.mean + heatmaps[metric_name] = Heatmap(data=heatmap) + return heatmaps + +metrics_table = { 
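+    # Maps metric names, as toggled under `metrics:` in the config YAML, to their aggregator classes.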
+ 'cosine_similarity': Cosine_Similarity, + 'mse': MSE +} + @click.command() -@click.option('--reps_a_path', - default="NEW_Representations_BEE-spoke-data_smol_llama-220M-GQA_train_4000.h5", - help="path to load first set of representations from.") -@click.option('--reps_b_path', - default="NEW_Representations_BEE-spoke-data_smol_llama-220M-openhermes_train_4000.h5", - help="path to load second set of representations from.") -def main(reps_a_path, reps_b_path): +@click.option('--config_yml', + default="./config.yml", + help='merge configuration file.') +def main(config_yml): + with open(config_yml, "r", encoding="utf-8") as fp: + config_yml = yaml.safe_load(fp) + + model_paths = config_yml['representation_paths'] + metrics_toggle = config_yml['metrics'] + skip_layers = config_yml['block_analysis_parameters']['skip_layers'] + results = Results() + use_metrics = {} + for metric in metrics_table: + if metrics_toggle[metric]: + use_metrics[metric] = metrics_table[metric] + device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - reps_a_path = Path(__file__).parent / reps_a_path - reps_b_path = Path(__file__).parent / reps_b_path + for path in model_paths: + assert Path(path).exists(), f"File not found: {path}" - assert reps_a_path.exists(), f"File not found: {reps_a_path}" - assert reps_b_path.exists(), f"File not found: {reps_b_path}" - - with h5py.File(reps_a_path, 'r') as representations_a, \ - h5py.File(reps_b_path, 'r') as representations_b: - - for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Layers', total=len(representations_a)): - metrics = { - 'Cosine Similarity' : CosineSimilarity(device=device), - 'MSE' : MSE(device=device) - } - - layer_results = Layer(WeightInfo(name=layer_a)) - if layer_a != layer_b: - raise ValueError(f'Layer mismatch: {layer_a} != {layer_b}') + if config_yml['compare_between_models']: + if len(model_paths) != 2: + raise ValueError("Expected 2 model paths for comparison") + + with h5py.File(model_paths[0], 'r') as representations_a, \ + h5py.File(model_paths[1], 'r') as representations_b: + + # Compare corresponding layer representations + results = compare_representations(representations_a, + representations_b, + metrics_classes=use_metrics, + device=device) - # Load the representations - layer_representations_a = representations_a[layer_a] - layer_representations_b = representations_b[layer_b] - - for batch_a, batch_b in tqdm(zip(layer_representations_a, layer_representations_b), desc='Batches', total=len(layer_representations_a), leave=False): - batch_a = torch.tensor(layer_representations_a[batch_a][:], device=device) - batch_b = torch.tensor(layer_representations_b[batch_b][:], device=device) - # Calculate the metrics for each batch - for _, metric in metrics.items(): - metric.process_batch(batch_a, batch_b) + results.save('results_compare') + + if config_yml['block_analysis']: + # Analyse individual layer representations + for reps_path in model_paths: - # Aggregate over the batches and add to the layer results - for name, metric in metrics.items(): - layer_results.add_metric(metric.aggregate(), name) - metric.clear() + metric_classes = list(use_metrics.values()) + + all_results = [] + for skip_layer in tqdm(skip_layers, desc='Skip Layers', total=len(skip_layers)): + all_results.append( + compute_skip_block_metrics(reps_path, skip_layer, metric_classes, device=device) + ) + + heatmaps = resultslist_to_heatmap(all_results, metric_names=[metric.__name__.lower() for metric in 
metric_classes]) + for metric_name, heatmap in heatmaps.items(): + results.others[reps_path + '||' + metric_name] = heatmap - # Add the layer to the results - results.add_layer(layer_results, layer_a) + results.save('results_block_analysis') - results.save('results_test') + results.save('results') if __name__ == '__main__': main() \ No newline at end of file From 4ce67b0cff548cfd81b47b59c69e5fb5136f418b Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 10:59:41 +0100 Subject: [PATCH 37/64] refactor and restructure --- representations/representation_metrics.py | 269 +++++++++------------- 1 file changed, 112 insertions(+), 157 deletions(-) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 927f6a97..6c2ca737 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -1,263 +1,218 @@ import torch import h5py -import random import numpy as np import click import yaml -import h5py - -from mergekit.config import MergeConfiguration -from mergekit.merge import MergeOptions -from mergekit.merge import run_merge -from mergekit.plot_tools.plot_tools import create_app, ResultsHandler +from pathlib import Path +from typing import List, Dict, Any, Optional +from tqdm import tqdm +import torch.nn.functional as F from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap -from mergekit.architecture import WeightInfo -from typing import List - -from tqdm import tqdm -from pathlib import Path -import torch.nn.functional as F +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.merge_methods.base import MergeMethod -class MetricAggregator(): - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): - pass +class MetricAggregator: + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + raise NotImplementedError - def aggregate(self): - pass + def aggregate(self) -> Metric: + raise NotImplementedError - def clear(self): - self.__init__() + def clear(self) -> None: + raise NotImplementedError class Cosine_Similarity(MetricAggregator): - def __init__(self, device="cpu"): - self.cosine_similarities = torch.tensor([]).to(device) + def __init__(self, device: str = "cpu"): + self.device = device + self.cosine_similarities = torch.tensor([], device=self.device) - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: batch_similarities = F.cosine_similarity(batch_a, batch_b, dim=1) - self.cosine_similarities = torch.cat((self.cosine_similarities, batch_similarities), dim=0) + self.cosine_similarities = torch.cat((self.cosine_similarities, batch_similarities)) - def aggregate(self): + def aggregate(self) -> Metric: hist = compute_histogram(self.cosine_similarities, 100) - return Metric( - mean_std=MeanStd( + mean_std=MeanStd( mean=self.cosine_similarities.mean().item(), - std=self.cosine_similarities.std().item()), - histogram=Histogram( + std=self.cosine_similarities.std().item() + ) + histogram=Histogram( count=hist[0], edges=hist[1], widths=hist[2] ) + self.clear() + return Metric( + histogram=histogram, + mean_std=mean_std ) -class MSE(MetricAggregator): - def __init__(self, device="cpu"): - self.square_errors = 
torch.tensor([]).to(device) + def clear(self) -> None: + self.cosine_similarities = torch.tensor([], device=self.device) +class MSE(MetricAggregator): + def __init__(self, device: str = "cpu"): + self.device = device + self.square_errors = torch.tensor([], device=self.device) - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor): - assert batch_a.size(1) == batch_b.size(1) + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: batch_square_errors = torch.square(batch_a - batch_b).flatten() - self.square_errors = torch.cat((self.square_errors, batch_square_errors), dim=0) + self.square_errors = torch.cat((self.square_errors, batch_square_errors)) - def aggregate(self): + def aggregate(self) -> Metric: hist = compute_histogram(self.square_errors, 100) - out = Metric( - mean_std=MeanStd( + mean_std=MeanStd( mean=self.square_errors.mean().item(), std=self.square_errors.std().item() - ), - histogram=Histogram( + ) + histogram=Histogram( count=hist[0], edges=hist[1], widths=hist[2] ) - ) self.clear() - return out + return Metric( + histogram=histogram, + mean_std=mean_std + ) + + def clear(self) -> None: + self.square_errors = torch.tensor([], device=self.device) class LayerByIndex: - def __init__(self, reps_path): + def __init__(self, reps_path: str): self.reps_path = reps_path self.representations = None self.layers = None - self.iter_index = 0 - + def __enter__(self): self.representations = h5py.File(self.reps_path, 'r') self.layers = list(self.representations.keys()) return self - + def __exit__(self, *args, **kwargs): if self.representations: self.representations.close() - - def __getitem__(self, idx): + + def __getitem__(self, idx: int): return self.representations[self.layers[idx]] - - def __len__(self): + + def __len__(self) -> int: return len(self.layers) - + def __iter__(self): - self.iter_index = 0 - return self - - def __next__(self): - if self.iter_index < len(self.layers): - layer = self.representations[self.layers[self.iter_index]] - self.iter_index += 1 - return layer - else: - raise StopIteration - - def batches_in_layer(self, idx): - return len(self.representations[self.layers[idx]]) - -def compare_representations(representations_a: h5py.File, representations_b: h5py.File, metrics_classes, device: str): - results=Results() - - # Compare corresponding layers from both models - for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Comparing Represenations at layer', total=len(representations_a), leave=False): - metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] + return iter(self.representations[layer] for layer in self.layers) - layer_results = Layer(WeightInfo(name=layer_a)) +def compare_representations(representations_a: h5py.File, representations_b: h5py.File, + metrics_classes: Dict[str, MetricAggregator], device: str) -> Dict[str, Any]: + results = Results() + + for layer_a, layer_b in tqdm(zip(representations_a, representations_b), + desc='Comparing Representations at layer', + total=len(representations_a)): if layer_a != layer_b: raise ValueError(f'Layer mismatch: {layer_a} != {layer_b}') - - # Load the representations - layer_representations_a = representations_a[layer_a] - layer_representations_b = representations_b[layer_b] - - for batch_a, batch_b in tqdm(zip(layer_representations_a, layer_representations_b), desc='Batch', total=len(layer_representations_a), leave=False): - batch_a = torch.tensor(layer_representations_a[batch_a][:], device=device) - batch_b = 
torch.tensor(layer_representations_b[batch_b][:], device=device) + + metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] + + for batch_a, batch_b in tqdm(zip(representations_a[layer_a], representations_b[layer_b]), + desc='Batch', total=len(representations_a[layer_a]), leave=False): + batch_a = torch.tensor(representations_a[layer_a][batch_a][:], device=device) + batch_b = torch.tensor(representations_b[layer_b][batch_b][:], device=device) # Calculate the metrics for each batch for metric in metrics: metric.process_batch(batch_a, batch_b) - + + layer_results = Layer(WeightInfo(name=layer_a)) # Aggregate over the batches and add to the layer results for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower().lower()) + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) metric.clear() - - # Add the layer to the results results.add_layer(layer_results, layer_a) - + return results -def compute_skip_block_metrics(reps_path:str, skip_layers:int, metric_classes:List[MetricAggregator], device:str='cpu'): +def compute_skip_block_metrics(reps_path: str, skip_layers: int, + metric_classes: List[MetricAggregator], device: str) -> Results: results = Results() with LayerByIndex(reps_path) as reps: - for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {skip_layers}-block, Block Start at Layer', total=len(reps)-skip_layers, leave=False): + for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {skip_layers}-block, Block Start at Layer', + total=len(reps) - skip_layers, leave=False): + if idx + skip_layers >= len(reps): + break + # Create metrics metrics = [metric_class(device=device) for metric_class in metric_classes] - - if idx + skip_layers >= len(reps): - continue block_end = reps[idx + skip_layers] - # Each metric processes every batch - for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', total=len(block_start), leave=False): + for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', + total=len(block_start), leave=False): batch_0 = torch.tensor(block_start[batch_0][:]).to(device) batch_1 = torch.tensor(block_end[batch_1][:]).to(device) - + for metric in metrics: metric.process_batch(batch_0, batch_1) # Aggregate metrics and add to results - layer_results = Layer(WeightInfo(name=f"{idx}")) + layer_results = Layer(WeightInfo(name=f"Layer {idx}")) for metric in metrics: layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower().lower()) - results.add_layer(layer_results, f"Layer {idx}") - - # Clear memory from metrics - for metric in metrics: - metric.clear() - return results - -def resultslist_to_heatmap(all_results, metric_names:List[str]) -> dict: - rows = len(all_results) - cols = max([len(result.layers) for result in all_results]) - heatmaps = {} - for metric_name in metric_names: - heatmap = np.full((rows, cols), np.nan) + results.add_layer(layer_results, f"Layer {idx}") - for i, result in enumerate(all_results): - # row = len(all_results) - (i+1) - for j, layer in enumerate(result.layers): - heatmap[i, j] = result.layers[layer].metrics[metric_name][0].mean_std.mean - heatmaps[metric_name] = Heatmap(data=heatmap) - return heatmaps + return results -metrics_table = { +METRICS_TABLE = { 'cosine_similarity': Cosine_Similarity, 'mse': MSE } @click.command() -@click.option('--config_yml', - default="./config.yml", - help='merge configuration file.') -def main(config_yml): +@click.option('--config_yml', 
default="./representations/config.yml", help='Merge configuration file.') +def main(config_yml: str): with open(config_yml, "r", encoding="utf-8") as fp: - config_yml = yaml.safe_load(fp) - - model_paths = config_yml['representation_paths'] - metrics_toggle = config_yml['metrics'] - skip_layers = config_yml['block_analysis_parameters']['skip_layers'] + config = yaml.safe_load(fp) - results = Results() + model_paths = config['representation_paths'] + metrics_toggle = config['metrics'] + skip_layers = config['block_analysis_parameters']['skip_layers'] - use_metrics = {} - for metric in metrics_table: - if metrics_toggle[metric]: - use_metrics[metric] = metrics_table[metric] + use_metrics = {name: METRICS_TABLE[name] for name, enabled in metrics_toggle.items() if enabled} - device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + device = torch.device("cuda" if torch.cuda.is_available() else + "mps" if torch.backends.mps.is_available() else "cpu") for path in model_paths: - assert Path(path).exists(), f"File not found: {path}" + if not Path(path).exists(): + raise FileNotFoundError(f"File not found: {path}") - if config_yml['compare_between_models']: + if config['compare_between_models']: if len(model_paths) != 2: raise ValueError("Expected 2 model paths for comparison") with h5py.File(model_paths[0], 'r') as representations_a, \ - h5py.File(model_paths[1], 'r') as representations_b: + h5py.File(model_paths[1], 'r') as representations_b: - # Compare corresponding layer representations - results = compare_representations(representations_a, - representations_b, - metrics_classes=use_metrics, - device=device) - - results.save('results_compare') + results_compare = compare_representations(representations_a, representations_b, + metrics_classes=use_metrics, device=device) - if config_yml['block_analysis']: - # Analyse individual layer representations - for reps_path in model_paths: - - metric_classes = list(use_metrics.values()) + results_compare.save('results_compare.pkl') - all_results = [] - for skip_layer in tqdm(skip_layers, desc='Skip Layers', total=len(skip_layers)): - all_results.append( - compute_skip_block_metrics(reps_path, skip_layer, metric_classes, device=device) - ) - - heatmaps = resultslist_to_heatmap(all_results, metric_names=[metric.__name__.lower() for metric in metric_classes]) - - for metric_name, heatmap in heatmaps.items(): - results.others[reps_path + '||' + metric_name] = heatmap - - results.save('results_block_analysis') + if config['block_analysis']: + for reps_path in model_paths: + results_block_analysis = Results() + for skip_layer in tqdm(skip_layers, desc='Skip Layers'): + results = compute_skip_block_metrics(reps_path, skip_layer, list(use_metrics.values()), device=device) + for layer_name, layer in results.layers.items(): + results_block_analysis.add_layer(layer, f"{skip_layer}_{layer_name}") - results.save('results') + results_block_analysis.save(f'results_block_analysis_{Path(reps_path).stem}.pkl') if __name__ == '__main__': main() \ No newline at end of file From 88ed18e7f6f40383d1d0a3ff267df75db7549e7c Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 11:36:59 +0100 Subject: [PATCH 38/64] clean up imports and remove hard coding --- .../visualise_representation_results.py | 29 ++----------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py index 578a7e9d..8e96c749 100644 --- 
a/representations/visualise_representation_results.py +++ b/representations/visualise_representation_results.py @@ -1,36 +1,13 @@ -import torch -import h5py -import random -import numpy as np import click -import yaml -import h5py - -from mergekit.config import MergeConfiguration -from mergekit.merge import MergeOptions -from mergekit.merge import run_merge -from mergekit.plot_tools.plot_tools import create_app, ResultsHandler, global_colours_list - -from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer -from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap -from mergekit.architecture import WeightInfo - - -from typing import List, Tuple, Dict -import plotly.graph_objs as go -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors -import numpy as np -from mergekit.graph import Task +from mergekit.plot_tools.plot_tools import create_app, ResultsHandler +from mergekit.metric_methods.base import Results @click.command() @click.option('--results_path', - default="./representation_results_test.pkl", + default="./representations/results.pkl", help="path to load the results from.") def main(results_path): results = Results() - print('warning: results_path is hardcoded in main()') - results_path = '/Users/elliotstein/Documents/Arcee/mergekit/representations/results_test.pkl' results = results.load(results_path) handler = ResultsHandler() From 1041298485f643f5d11b09c3e217218210c83f55 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 11:37:12 +0100 Subject: [PATCH 39/64] tidy up tqdm --- representations/representation_metrics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 6c2ca737..56971acc 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -114,14 +114,14 @@ def compare_representations(representations_a: h5py.File, representations_b: h5p for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Comparing Representations at layer', - total=len(representations_a)): + total=len(representations_a), initial = 1): if layer_a != layer_b: raise ValueError(f'Layer mismatch: {layer_a} != {layer_b}') metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] for batch_a, batch_b in tqdm(zip(representations_a[layer_a], representations_b[layer_b]), - desc='Batch', total=len(representations_a[layer_a]), leave=False): + desc='Batch', total=len(representations_a[layer_a]), leave=False, initial = 1): batch_a = torch.tensor(representations_a[layer_a][batch_a][:], device=device) batch_b = torch.tensor(representations_b[layer_b][batch_b][:], device=device) # Calculate the metrics for each batch @@ -143,7 +143,7 @@ def compute_skip_block_metrics(reps_path: str, skip_layers: int, results = Results() with LayerByIndex(reps_path) as reps: for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {skip_layers}-block, Block Start at Layer', - total=len(reps) - skip_layers, leave=False): + total=len(reps) - skip_layers, leave=False, initial = 1): if idx + skip_layers >= len(reps): break @@ -152,7 +152,7 @@ def compute_skip_block_metrics(reps_path: str, skip_layers: int, block_end = reps[idx + skip_layers] for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', - total=len(block_start), leave=False): + 
total=len(block_start), leave=False, initial = 1): batch_0 = torch.tensor(block_start[batch_0][:]).to(device) batch_1 = torch.tensor(block_end[batch_1][:]).to(device) @@ -205,9 +205,9 @@ def main(config_yml: str): results_compare.save('results_compare.pkl') if config['block_analysis']: - for reps_path in model_paths: + for reps_path in tqdm(model_paths, desc='Model', leave=False, total=len(model_paths), initial = 1): results_block_analysis = Results() - for skip_layer in tqdm(skip_layers, desc='Skip Layers'): + for skip_layer in tqdm(skip_layers, desc='Skip Layers', initial = 1): results = compute_skip_block_metrics(reps_path, skip_layer, list(use_metrics.values()), device=device) for layer_name, layer in results.layers.items(): results_block_analysis.add_layer(layer, f"{skip_layer}_{layer_name}") From 01d5a2b99d0c1e37ac2a983f8879b683747ba96d Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 11:37:29 +0100 Subject: [PATCH 40/64] improve robustness of load and save using pathlib --- mergekit/metric_methods/base.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index a166a31d..1a8a0c73 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -22,7 +22,7 @@ from dataclasses import dataclass, field from collections import defaultdict import torch - +from pathlib import Path import pickle class MetricMethod(MergeMethod): @@ -188,12 +188,19 @@ def print_metric_summary(self): print(f" Has model reference: {'Yes' if info['has_model_ref'] else 'No'}") def save(self, path: str): - path = path + '.pkl' if not path.endswith('.pkl') else path - with open(path, 'wb') as f: + path = Path(path) + if not path.suffix or path.suffix != '.pkl': + path = path.with_suffix('.pkl') + + with path.open('wb') as f: pickle.dump(self, f) def load(self, path: str): - with open(path, 'rb') as f: - results = pickle.load(f) - assert isinstance(results, Results), "Loaded object is not a Results object" - return results \ No newline at end of file + path_obj = Path(path) + if path_obj.exists() and path_obj.is_file(): + with open(path_obj, 'rb') as f: + results = pickle.load(f) + assert isinstance(results, Results), "Loaded object is not a Results object" + return results + else: + raise FileNotFoundError(f"The path {path} does not exist or is not a file.") \ No newline at end of file From f64a71ec4fccf050b7509447a27ffbf89794245e Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 12:43:06 +0100 Subject: [PATCH 41/64] reintroduced heatmap functionality --- representations/representation_metrics.py | 49 +++++++++++++++++------ 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 56971acc..7d025bb6 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -109,8 +109,9 @@ def __iter__(self): return iter(self.representations[layer] for layer in self.layers) def compare_representations(representations_a: h5py.File, representations_b: h5py.File, - metrics_classes: Dict[str, MetricAggregator], device: str) -> Dict[str, Any]: - results = Results() + metrics_classes: Dict[str, MetricAggregator], device: str, results: Results) -> Dict[str, Any]: + if results is None: + results = Results() for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Comparing Representations at layer', @@ -124,6 +125,7 @@ def 
compare_representations(representations_a: h5py.File, representations_b: h5p desc='Batch', total=len(representations_a[layer_a]), leave=False, initial = 1): batch_a = torch.tensor(representations_a[layer_a][batch_a][:], device=device) batch_b = torch.tensor(representations_b[layer_b][batch_b][:], device=device) + # Calculate the metrics for each batch for metric in metrics: metric.process_batch(batch_a, batch_b) @@ -162,12 +164,29 @@ def compute_skip_block_metrics(reps_path: str, skip_layers: int, # Aggregate metrics and add to results layer_results = Layer(WeightInfo(name=f"Layer {idx}")) for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower().lower()) + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) results.add_layer(layer_results, f"Layer {idx}") return results +def results_list_to_heatmap(all_results, metric_names:List[str]) -> dict: + rows = len(all_results) + cols = max([len(result.layers) for result in all_results]) + heatmaps = {} + for metric_name in metric_names: + heatmap = np.full((rows, cols), np.nan) + + for i, result in enumerate(all_results): + for j, layer in enumerate(result.layers): + heatmap[i, j] = result.layers[layer].metrics[metric_name][0].mean_std.mean + heatmaps[metric_name] = Heatmap(data=heatmap, + update_layout_options = { + 'xaxis_title': 'Layer Number', + 'yaxis_title': 'Block Size', + }) + return heatmaps + METRICS_TABLE = { 'cosine_similarity': Cosine_Similarity, 'mse': MSE @@ -187,11 +206,12 @@ def main(config_yml: str): device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") - + for path in model_paths: if not Path(path).exists(): raise FileNotFoundError(f"File not found: {path}") + all_results = Results() if config['compare_between_models']: if len(model_paths) != 2: raise ValueError("Expected 2 model paths for comparison") @@ -199,20 +219,23 @@ def main(config_yml: str): with h5py.File(model_paths[0], 'r') as representations_a, \ h5py.File(model_paths[1], 'r') as representations_b: - results_compare = compare_representations(representations_a, representations_b, - metrics_classes=use_metrics, device=device) - - results_compare.save('results_compare.pkl') + all_results = compare_representations(representations_a, representations_b, + metrics_classes=use_metrics, device=device, results=all_results) if config['block_analysis']: for reps_path in tqdm(model_paths, desc='Model', leave=False, total=len(model_paths), initial = 1): - results_block_analysis = Results() + results_list = [] + metric_classes = list(use_metrics.values()) for skip_layer in tqdm(skip_layers, desc='Skip Layers', initial = 1): - results = compute_skip_block_metrics(reps_path, skip_layer, list(use_metrics.values()), device=device) - for layer_name, layer in results.layers.items(): - results_block_analysis.add_layer(layer, f"{skip_layer}_{layer_name}") + results_list.append( + compute_skip_block_metrics(reps_path, skip_layer, metric_classes=metric_classes, device=device) + ) + + heatmaps = results_list_to_heatmap(results_list, metric_names=[metric.__name__.lower() for metric in metric_classes]) + for metric_name, heatmap in heatmaps.items(): + all_results.others[reps_path + '||' + metric_name] = heatmap - results_block_analysis.save(f'results_block_analysis_{Path(reps_path).stem}.pkl') + all_results.save('results.pkl') if __name__ == '__main__': main() \ No newline at end of file From 8b05c2af08d1c9dfdb5f050b241fccbb3f612e79 Mon Sep 17 00:00:00 2001 
From: ElliotStein Date: Tue, 9 Jul 2024 12:43:22 +0100 Subject: [PATCH 42/64] allow for plot keyworks to be passed into Heatmap object --- mergekit/metric_methods/base.py | 1 + mergekit/plot_tools/plot_tools.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 1a8a0c73..67cef20f 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -55,6 +55,7 @@ class MeanStd: @dataclass class Heatmap: data: torch.Tensor + update_layout_options: Optional[Dict] = None @dataclass class Histogram: diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 3f9ce101..e1733df1 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -254,7 +254,7 @@ def create_heatmap_section(results_handler): heatmap_sections.append(html.Div([ html.H3(f'Heatmap: {title}', style={'textAlign': 'center'}), - dcc.Graph(id=f'heatmap-plot-{i}', style={'width': '30%', 'height': '30%'}) + dcc.Graph(id=f'heatmap-plot-{i}', style={'width': '50%', 'height': '50%', 'position': 'relative'}) ], className='container-fluid')) return html.Div(heatmap_sections) else: @@ -377,11 +377,15 @@ def update_heatmap_plot(_key=key): zmax=np.nanmax(heatmap_data), # Set the scale max to the max data value colorbar=dict(title='Scale') # Customize the color bar )) - + default_layout_options = { + 'xaxis_title':"X Axis", + 'yaxis_title':"Y Axis" + } + if results_handler.results.others[key].update_layout_options: + default_layout_options.update(results_handler.results.others[key].update_layout_options) fig.update_layout( title=f"Heatmap: {_key}", - xaxis_title="X Axis", - yaxis_title="Y Axis" + **default_layout_options ) return fig From b8f7e325bc6119aa84f15f1ffbe712babf5d90de Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 12:44:23 +0100 Subject: [PATCH 43/64] example config --- representations/config.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 representations/config.yml diff --git a/representations/config.yml b/representations/config.yml new file mode 100644 index 00000000..b2260d20 --- /dev/null +++ b/representations/config.yml @@ -0,0 +1,13 @@ +representation_paths: +- ./representations/Representations_BEE-spoke-data_smol_llama-220M-GQA_train_4000.h5 +- ./representations/Representations_BEE-spoke-data_smol_llama-220M-openhermes_train_4000.h5 + +metrics: + cosine_similarity: true + mse: true + +compare_between_models: true +block_analysis: true +block_analysis_parameters: + skip_layers: [1,2,3,4,5,6,7,8,9] +analyse_individual_models: true \ No newline at end of file From c4df572dd88fccf324cf0624845c1700a2701d3b Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 9 Jul 2024 16:36:04 +0100 Subject: [PATCH 44/64] improve implementation consistency --- representations/representation_metrics.py | 58 +++++++++++++---------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 7d025bb6..4fe29013 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -107,38 +107,44 @@ def __len__(self) -> int: def __iter__(self): return iter(self.representations[layer] for layer in self.layers) - -def compare_representations(representations_a: h5py.File, representations_b: h5py.File, + +def compare_representations(reps_path_a: str, reps_path_b: str, metrics_classes: Dict[str, 
MetricAggregator], device: str, results: Results) -> Dict[str, Any]: if results is None: results = Results() - for layer_a, layer_b in tqdm(zip(representations_a, representations_b), - desc='Comparing Representations at layer', - total=len(representations_a), initial = 1): - if layer_a != layer_b: - raise ValueError(f'Layer mismatch: {layer_a} != {layer_b}') - - metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] + with LayerByIndex(reps_path_a) as representations_a, \ + LayerByIndex(reps_path_b) as representations_b: - for batch_a, batch_b in tqdm(zip(representations_a[layer_a], representations_b[layer_b]), - desc='Batch', total=len(representations_a[layer_a]), leave=False, initial = 1): - batch_a = torch.tensor(representations_a[layer_a][batch_a][:], device=device) - batch_b = torch.tensor(representations_b[layer_b][batch_b][:], device=device) + for layer_a, layer_b in tqdm(zip(representations_a, representations_b), + desc='Comparing Representations at layer', + total=len(representations_a), initial = 1): - # Calculate the metrics for each batch - for metric in metrics: - metric.process_batch(batch_a, batch_b) + layer_a_name = layer_a.name.split('/')[-1] + layer_b_name = layer_b.name.split('/')[-1] + if layer_a_name != layer_b_name: + raise ValueError(f'Layer mismatch: {layer_a_name} != {layer_b_name}') - layer_results = Layer(WeightInfo(name=layer_a)) - # Aggregate over the batches and add to the layer results - for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - metric.clear() + metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] - results.add_layer(layer_results, layer_a) + for batch_a, batch_b in tqdm(zip(layer_a, layer_b), + desc='Batch', total=len(layer_a), leave=False, initial = 1): + batch_a = torch.tensor(layer_a[batch_a][:], device=device) + batch_b = torch.tensor(layer_b[batch_b][:], device=device) + + # Calculate the metrics for each batch + for metric in metrics: + metric.process_batch(batch_a, batch_b) - return results + layer_results = Layer(WeightInfo(name=layer_a_name)) + # Aggregate over the batches and add to the layer results + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + metric.clear() + + results.add_layer(layer_results, layer_a_name) + + return results def compute_skip_block_metrics(reps_path: str, skip_layers: int, metric_classes: List[MetricAggregator], device: str) -> Results: @@ -216,10 +222,10 @@ def main(config_yml: str): if len(model_paths) != 2: raise ValueError("Expected 2 model paths for comparison") - with h5py.File(model_paths[0], 'r') as representations_a, \ - h5py.File(model_paths[1], 'r') as representations_b: + # with h5py.File(model_paths[0], 'r') as representations_a, \ + # h5py.File(model_paths[1], 'r') as representations_b: - all_results = compare_representations(representations_a, representations_b, + all_results = compare_representations(model_paths[0], model_paths[1], metrics_classes=use_metrics, device=device, results=all_results) if config['block_analysis']: From 5e051a8be7dd2a7746219491568eb56c6fcaa591 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 16 Jul 2024 13:22:20 +0100 Subject: [PATCH 45/64] experimental linearity score metric --- representations/config.yml | 1 + representations/representation_metrics.py | 61 +++++++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/representations/config.yml 
b/representations/config.yml index b2260d20..9c60ca59 100644 --- a/representations/config.yml +++ b/representations/config.yml @@ -5,6 +5,7 @@ representation_paths: metrics: cosine_similarity: true mse: true + linearity_score: true compare_between_models: true block_analysis: true diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 4fe29013..4cb264f6 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -84,6 +84,61 @@ def aggregate(self) -> Metric: def clear(self) -> None: self.square_errors = torch.tensor([], device=self.device) +class Linearity_Score(MetricAggregator): + def __init__(self, device: str = "cpu"): + self.device = device + self.iterations = 0 + self.max_iterations = 250 + self.A = None + self.optimiser = None + self.initialised = False + self.done = False + self.losses = [] + + self.absolute_square_sum = 0 + self.num_elements = 0 + + def _first_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + batch_size, dimension = batch_a.size() + + self.A = torch.empty(dimension, dimension, device=self.device) + torch.nn.init.normal_(self.A) + self.A = torch.nn.Parameter(self.A) + + self.optimiser = torch.optim.SGD([self.A], lr=0.0001) + self.initialised = True + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + batch_a = batch_a / torch.norm(batch_a, dim=1, keepdim=True) + batch_b = batch_b / torch.norm(batch_b, dim=1, keepdim=True) # Check dimensionality (X) + if not self.initialised: + self._first_batch(batch_a, batch_b) + if self.done: # stop training A and evaluate + residuals = batch_a @ self.A - batch_b + self.absolute_square_sum += torch.abs(residuals).sum().item() + self.num_elements += residuals.numel() + + else: + + loss = torch.norm(batch_a @ self.A - batch_b) ** 2 + loss.backward() + self.losses.append(loss.item()) + print(f'Loss: {loss.item()}') + self.optimiser.step() + + self.iterations += 1 + + if self.iterations >= self.max_iterations: + self.done = True + + def aggregate(self) -> Metric: + linearity_score = 1 - self.absolute_square_sum / self.num_elements + self.clear() + return Metric(mean_std=MeanStd(mean=linearity_score)) + + def clear(self) -> None: + pass + class LayerByIndex: def __init__(self, reps_path: str): self.reps_path = reps_path @@ -195,7 +250,8 @@ def results_list_to_heatmap(all_results, metric_names:List[str]) -> dict: METRICS_TABLE = { 'cosine_similarity': Cosine_Similarity, - 'mse': MSE + 'mse': MSE, + 'linearity_score': Linearity_Score } @click.command() @@ -222,9 +278,6 @@ def main(config_yml: str): if len(model_paths) != 2: raise ValueError("Expected 2 model paths for comparison") - # with h5py.File(model_paths[0], 'r') as representations_a, \ - # h5py.File(model_paths[1], 'r') as representations_b: - all_results = compare_representations(model_paths[0], model_paths[1], metrics_classes=use_metrics, device=device, results=all_results) From bb8fce27dda66a0282a821f0cb0818c78490d61e Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 16 Jul 2024 13:22:43 +0100 Subject: [PATCH 46/64] add matplotlib to optional dependencies --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5afa4bb5..320e7dc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dev = ["black~=24.4.2", "isort~=5.13.2", "pre-commit~=3.7.1"] test = ["pytest~=8.2.1"] evolve = ["ray", "cma", "lm_eval", "wandb"] vllm = ["vllm==0.3.2", "lm_eval[vllm]"] 
-interactive_plot = ["networkx", "plotly"] +interactive_plot = ["networkx", "plotly", matplotlib"] [project.urls] repository = "https://github.com/cg123/mergekit" From e1c1ecdc4ed75f23b0d4802e5816261295c911ac Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 16 Jul 2024 16:07:46 +0100 Subject: [PATCH 47/64] remove quantisation and update environment reqs --- pyproject.toml | 3 ++- representations/store_representations.py | 7 ------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fe167b36..ca8ea9f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,8 @@ dev = ["black~=24.4.2", "isort~=5.13.2", "pre-commit~=3.7.1"] test = ["pytest~=8.2.1"] evolve = ["ray", "cma", "lm_eval", "wandb"] vllm = ["vllm==0.3.2", "lm_eval[vllm]"] -interactive_plot = ["networkx", "plotly", matplotlib"] +interactive_plot = ["networkx", "plotly", "matplotlib"] +representations = ["h5py", "datasets", "bitsandbytes"] [project.urls] repository = "https://github.com/cg123/mergekit" diff --git a/representations/store_representations.py b/representations/store_representations.py index f0d3c16f..f98cc224 100644 --- a/representations/store_representations.py +++ b/representations/store_representations.py @@ -75,12 +75,6 @@ def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, else "mps" if torch.backends.mps.is_available() \ else "cpu" - # if resource is a problem - quantization_config = BitsAndBytesConfig(load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16) - dataset = datasets.load_dataset(dataset, split=dataset_subset) if dataset_size: dataset = dataset.select(range(dataset_size)) @@ -88,7 +82,6 @@ def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", - quantization_config=quantization_config if device == "cuda" else None, output_hidden_states=True) From 9db9d3bc360cd959e0f19abf3eb3c65979b203e0 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Mon, 22 Jul 2024 15:38:26 +0100 Subject: [PATCH 48/64] tidy up and add missing dependency --- pyproject.toml | 2 +- representations/store_representations.py | 32 ++++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca8ea9f0..8814984d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dev = ["black~=24.4.2", "isort~=5.13.2", "pre-commit~=3.7.1"] test = ["pytest~=8.2.1"] evolve = ["ray", "cma", "lm_eval", "wandb"] vllm = ["vllm==0.3.2", "lm_eval[vllm]"] -interactive_plot = ["networkx", "plotly", "matplotlib"] +interactive_plot = ["networkx", "plotly", "matplotlib", "dash"] representations = ["h5py", "datasets", "bitsandbytes"] [project.urls] diff --git a/representations/store_representations.py b/representations/store_representations.py index f98cc224..c2c1964c 100644 --- a/representations/store_representations.py +++ b/representations/store_representations.py @@ -60,31 +60,20 @@ def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tens return last_non_padded_hidden_states -@click.command() -@click.option('--model_path', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') -@click.option('--output_path', default="./representations/", help='folder to store the result in.') -@click.option('--dataset', default="arcee-ai/sec-data-mini", help='dataset to use.') -@click.option('--batch_size', default=8, 
help='batch size.') -@click.option('--max_length', default=1024, help='maximum length of the input.') -@click.option('--dataset_size', default=4000, help='size of the dataset.') -@click.option('--dataset_column', default="text", help='column of the dataset to use.') -@click.option('--dataset_subset', default="train", help='subset of the dataset to use.') -def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, dataset_column, dataset_subset): +def store_representations(model_path, output_path, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): device = "cuda" if torch.cuda.is_available() \ else "mps" if torch.backends.mps.is_available() \ else "cpu" - dataset = datasets.load_dataset(dataset, split=dataset_subset) + dataset = datasets.load_dataset(dataset_name, split=dataset_subset) if dataset_size: dataset = dataset.select(range(dataset_size)) - model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", output_hidden_states=True) - tokenizer = AutoTokenizer.from_pretrained(model_path) if not tokenizer.pad_token: @@ -96,10 +85,10 @@ def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True) - output_name = f'Representations_{model.name_or_path.replace("/","_")}_{dataset_subset}_{dataset_size}' - assert not os.path.exists(output_path+f'{output_name}.h5'), f'{output_name}.h5 already exists.' + output_name = f'Representations_{model.name_or_path}_{dataset_name}_{dataset_size}.h5'.replace("/","_") + assert not os.path.exists(output_name), f'{output_name} already exists.' - with h5py.File(f'{output_name}.h5', 'w') as h5file: + with h5py.File(output_name, 'w') as h5file: for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")): inputs = tokenizer(batch, return_tensors="pt", padding="longest", max_length=max_length, truncation=True).to(device) with torch.no_grad(): @@ -122,5 +111,16 @@ def main(model_path, output_path, dataset, batch_size, max_length, dataset_size, does not match expected number of hidden layers." 
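The stored layout, as the LayerByIndex reader introduced later in this series expects, is one HDF5 group per layer containing one dataset per batch of last-token hidden states. A minimal sketch of reading such a file back (the file name is illustrative, following the naming scheme above; h5py and torch are assumed to be installed):

import h5py
import torch

reps_file = "Representations_BEE-spoke-data_smol_llama-220M-GQA_arcee-ai_sec-data-mini_4000.h5"

with h5py.File(reps_file, "r") as representations:
    for layer_name in representations.keys():  # note: group members are iterated in alphanumeric key order
        layer = representations[layer_name]
        for batch_name in layer:
            hidden_states = torch.tensor(layer[batch_name][:])  # shape: (batch_size, hidden_dim)
            # feed each batch to a metric's process_batch, or stack batches for one-shot metrics
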
+@click.command() +@click.option('--model_path', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') +@click.option('--output_path', default="./representations/", help='folder to store the result in.') +@click.option('--dataset_name', default="arcee-ai/sec-data-mini", help='dataset to use.') +@click.option('--batch_size', default=8, help='batch size.') +@click.option('--max_length', default=1024, help='maximum length of the input.') +@click.option('--dataset_size', default=4000, help='size of the dataset.') +@click.option('--dataset_column', default="text", help='column of the dataset to use.') +@click.option('--dataset_subset', default="train", help='subset of the dataset to use.') +def main(model_path, output_path, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): + store_representations(model_path, output_path, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset) if __name__ == "__main__": main() From af2bf1afd640238382cb9b00b33a4ebfcba98358 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 25 Jul 2024 12:37:59 +0100 Subject: [PATCH 49/64] Major restructuring of Results, Results handling, metrics --- examples/metrics-small.yml | 1 + mergekit/metric_methods/all_metrics.py | 115 +++--- mergekit/metric_methods/base.py | 155 +++++--- mergekit/metric_methods/metrics.py | 79 ++-- mergekit/plot_tools/plot_tools.py | 429 ++++++++++++++-------- mergekit/scripts/run_metrics.py | 71 +++- representations/representation_metrics.py | 115 +++++- representations/representations.py | 156 ++++++++ 8 files changed, 802 insertions(+), 319 deletions(-) create mode 100644 representations/representations.py diff --git a/examples/metrics-small.yml b/examples/metrics-small.yml index 6c6ec1a9..85e3b5ef 100644 --- a/examples/metrics-small.yml +++ b/examples/metrics-small.yml @@ -5,4 +5,5 @@ models: metric_method: all parameters: intra_model_metrics: true + inter_model_metrics: true dtype: float32 diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py index b78a25f7..0782dde4 100644 --- a/mergekit/metric_methods/all_metrics.py +++ b/mergekit/metric_methods/all_metrics.py @@ -101,19 +101,20 @@ def execute( self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs ) -> torch.Tensor: weights = list(tensors.values()) - validate_tensors(weights, self.weight_info, expected_tensors=2) layer_results = Layer(metrics={}, weight_info=self.weight_info) - layer_results.add_metric(cosine_similarity(weights, return_heatmap=False), name = 'cosine_similarity') - layer_results.add_metric(smape(weights), name = 'smape') - layer_results.add_metric(scale(weights, return_heatmap=False), name = 'scale') - layer_results.add_metric(mse(weights, return_heatmap=False), name = 'mse') - if self.intra_model_metrics: - model_refs = list(tensors.keys()) - layer_results.add_metric_list(metric_list=weight_magnitude(weights, model_refs), name='weight_magnitude') - layer_results.add_metric_list(metric_list=numerical_rank(weights, model_refs), name='numerical_rank') + validate_tensors(weights, self.weight_info, expected_tensors=1) + layer_results.add_metric(weight_magnitude(weights[0]), name='weight_magnitude') + layer_results.add_metric(numerical_rank(weights[0]), name='numerical_rank') + else: + validate_tensors(weights, self.weight_info, expected_tensors=2) + layer_results.add_metric(cosine_similarity(weights, return_heatmap=False), name = 'cosine_similarity') + layer_results.add_metric(smape(weights), name = 'smape') + 
layer_results.add_metric(scale(weights, return_heatmap=False), name = 'scale') + layer_results.add_metric(mse(weights, return_heatmap=False), name = 'mse') + return layer_results @@ -145,46 +146,46 @@ def execute( v_proj[model_references[0]], o_proj[model_references[0]], self.weight_info) - k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(k_proj[model_references[1]], - q_proj[model_references[1]], - v_proj[model_references[1]], - o_proj[model_references[1]], - self.weight_info) - - # Metrics for K, V, Q, O projections - - - # Metrics for heads - model_0_heads = torch.cat([k_proj_0, v_proj_0, q_proj_0, o_proj_0], dim=1) - model_1_heads = torch.cat([k_proj_1, v_proj_1, q_proj_1, o_proj_1], dim=1) - layer_results = Layer(metrics={}, weight_info=self.weight_info) - - - layer_results.add_metric(cosine_similarity([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True), - name = 'cosine_similarity') - layer_results.add_metric(smape([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)]), - name = 'smape') - layer_results.add_metric(scale([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=True), - name = 'scale') - layer_results.add_metric(mse([model_0_heads.view(model_0_heads.shape[0], -1), - model_1_heads.view(model_1_heads.shape[0], -1)], - return_heatmap=False), - name = 'mse') + if self.intra_model_metrics: - layer_results.add_metric_list( - metric_list=weight_magnitude([model_0_heads, model_1_heads], model_refs=model_references), + layer_results.add_metric( + metric=weight_magnitude(model_0_heads), name='weight_magnitude' ) + else: + + k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(k_proj[model_references[1]], + q_proj[model_references[1]], + v_proj[model_references[1]], + o_proj[model_references[1]], + self.weight_info) + + + model_1_heads = torch.cat([k_proj_1, v_proj_1, q_proj_1, o_proj_1], dim=1) + + + + layer_results.add_metric(cosine_similarity([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True), + name = 'cosine_similarity') + layer_results.add_metric(smape([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)]), + name = 'smape') + layer_results.add_metric(scale([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=True), + name = 'scale') + layer_results.add_metric(mse([model_0_heads.view(model_0_heads.shape[0], -1), + model_1_heads.view(model_1_heads.shape[0], -1)], + return_heatmap=False), + name = 'mse') + return layer_results @@ -203,6 +204,8 @@ def __eq__(self, other): class LayerNormTask(Task[torch.Tensor]): gather_tensors: GatherTensors weight_info: WeightInfo + intra_model_metrics: bool = False + def uses_accelerator(self) -> bool: return True @@ -216,14 +219,20 @@ def execute( tensors = list(tensors.values()) assert tensors[0].dim() == 1, "LayerNorm tensors must be 2D" - assert tensors[1].dim() == 1, "LayerNorm tensors must be 2D" layer_results = Layer(metrics={}, weight_info=self.weight_info) - - layer_results.add_metric(cosine_similarity([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'cosine_similarity') - layer_results.add_metric(smape([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)]), name = 'smape') - 
layer_results.add_metric(scale([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'scale') - layer_results.add_metric(mse([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'mse') + + if self.intra_model_metrics: + layer_results.add_metric( + metric=weight_magnitude(tensors[0].unsqueeze(1)), + name='weight_magnitude' + ) + else: + assert tensors[1].dim() == 1, "LayerNorm tensors must be 2D" + layer_results.add_metric(cosine_similarity([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'cosine_similarity') + layer_results.add_metric(smape([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)]), name = 'smape') + layer_results.add_metric(scale([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'scale') + layer_results.add_metric(mse([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name = 'mse') return layer_results @@ -253,10 +262,12 @@ def group_label(self) -> Optional[str]: # Metric method class AllMetric(MetricMethod): - attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {} - attn_info_dict: Optional[Dict[str, WeightInfo]] = {} - attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] # hard-coded for now - block_count: Optional[int] = 0 + def __init__(self) -> None: + super().__init__() + self.attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {} + self.attn_info_dict: Optional[Dict[str, WeightInfo]] = {} + self.attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj'] # hard-coded for now + self.block_count: Optional[int] = 0 def make_task( self, @@ -276,7 +287,7 @@ def make_task( intra_model_metrics=parameters['intra_model_metrics'] ) elif 'layernorm' in output_weight.name: - return LayerNormTask(gather_tensors=tensors, weight_info=output_weight) + return LayerNormTask(gather_tensors=tensors, weight_info=output_weight, intra_model_metrics=parameters['intra_model_metrics']) else: # Executor expects a task to be returned return DummyTask(gather_tensors=tensors, weight_info=output_weight) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 67cef20f..48dd7cf2 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -29,6 +29,7 @@ class MetricMethod(MergeMethod): pass # Structure of the results object +# OLD # Results # └── layers: Dict[str, Layer] @@ -46,6 +47,26 @@ class MetricMethod(MergeMethod): # For metrics which compare between models, (e.g. cosine similarity) the list will contain a single Metric object storing the comparison data. # For metrics which analyse individual models, (e.g. intrinsic dimension) the list will contain a Metric object for each model. +# New + +# Results +# ├── model_ref: Optional[List[ModelReference]] # One for individual model, two for comparison +# └── layers: Dict[str, Layer] +# └── Layer +# ├── weight_info: WeightInfo (remove?) 
+# └── metrics: Dict[str, Metric] +# └── Metric +# ├── histogram: Optional[Histogram] +# ├── mean_std: Optional[MeanStd] +# ├── scatter_plot: Optional[ScatterPlot] +# └── heatmap: Optional[Heatmap] +from enum import Enum + +class PlotType(Enum): + HISTOGRAM = 'histogram' + MEAN_STD = 'mean_std' + SCATTER_PLOT = 'scatter_plot' + HEATMAP = 'heatmap' @dataclass class MeanStd: @@ -63,12 +84,17 @@ class Histogram: edges: List[float] widths: List[float] +@dataclass +class ScatterPlot: + x: List[float] + y: List[float] + @dataclass class Metric: histogram: Optional[Histogram] = None mean_std: Optional[MeanStd] = None heatmap: Optional[Heatmap] = None - model_ref: Optional[ModelReference] = None # For intra-model metrics. + scatter_plot: Optional[ScatterPlot] = None def filled_attributes(self) -> List[str]: filled_attrs = [] @@ -80,20 +106,20 @@ def filled_attributes(self) -> List[str]: @dataclass class Layer: weight_info: WeightInfo - metrics: Dict[str, List[Metric]] = field(default_factory=dict) + metrics: Dict[str, Metric] = field(default_factory=dict) def metrics_with_attribute(self, attribute: str) -> List[str]: - return [name for name, metric in self.metrics.items() if attribute in metric[0].filled_attributes()] + return [name for name, metric in self.metrics.items() if attribute in metric.filled_attributes()] def add_metric(self, metric: Metric, name: str): if name not in self.metrics.keys(): - self.metrics[name] = [metric] + self.metrics[name] = metric else: - self.metrics[name].append(metric) + raise ValueError(f"Metric with name {name} already exists in layer {self.weight_info.layer_name}.") - def add_metric_list(self, metric_list: List[Metric], name: str): - for metric in metric_list: - self.add_metric(metric, name) + # def add_metric_list(self, metric_list: List[Metric], name: str): + # for metric in metric_list: + # self.add_metric(metric, name) def expand_to_fit(all_layer_names: List[str], values: List[float], subset_layer_names: List[str]) -> List[float]: """ @@ -115,38 +141,54 @@ def expand_to_fit(all_layer_names: List[str], values: List[float], subset_layer_ result[i] = subset_dict[layer] return result + +from typing import List, Tuple +from mergekit.graph import Task class Results: # Class to store the statistics for each layer def __init__(self): self.layers: Dict[str, Layer] = {} - self.others: Dict[str, Metric] = {} - + self.across_layer_metrics: Dict[str, Metric] = {} + self.model_refs: Optional[List[ModelReference]] = None + def add_layer(self, layer: Layer, name: str): if name not in self.layers.keys(): self.layers[name] = layer - def get_metric(self, layer_name: str, metric_name: str) -> Metric: - return self.get_layer(layer_name, metric_name) + # def get_metric(self, layer_name: str, metric_name: str) -> Metric: + # return self.get_layer(layer_name, metric_name) # Doesnt' Work! 
(X) + def load_metrics(self, metrics: List[Tuple[Task, Layer]], model_refs: Optional[List[ModelReference]] = None): + self.model_refs = model_refs + for task, metric in metrics: + if metric is not None: + self.add_layer(metric, name=task.weight_info.name) + return self def get_lineplot_data(self, metric_name: str): - means, stds = defaultdict(list), defaultdict(list) - layers = [] - - for name, layer in self.layers.items(): - if metric_name in layer.metrics: - for model_result in layer.metrics[metric_name]: - model_ref = model_result.model_ref if model_result.model_ref else 'all' - means[model_ref].append(model_result.mean_std.mean) - stds[model_ref].append(model_result.mean_std.std) - layers.append(name) - - means_list, stds_list, model_references = list(means.values()), list(stds.values()), list(means.keys()) - for i, model_ref in enumerate(model_references): - means_list[i] = expand_to_fit(all_layer_names=list(self.layers.keys()), values=means_list[i], subset_layer_names=layers) - stds_list[i] = expand_to_fit(all_layer_names=list(self.layers.keys()), values=stds_list[i], subset_layer_names=layers) - - return means_list, stds_list, model_references + means, stds = [],[] + + available_line_plots = self.available_plot_types(PlotType.MEAN_STD.value) + assert metric_name in available_line_plots, f"Metric {metric_name} does not have mean/std data available." + + layers_with_data = available_line_plots[metric_name] + means = [self.layers[layer].metrics[metric_name].mean_std.mean for layer in layers_with_data] + stds = [self.layers[layer].metrics[metric_name].mean_std.std for layer in layers_with_data] + + # for name, layer in self.layers.items(): + # if metric_name in layer.metrics: + # # for model_result in layer.metrics[metric_name]: + # # model_ref = model_result.model_refs if model_result.model_refs else 'all' + # means.append(layer.metrics[metric_name].mean_std.mean) + # stds.append(layer.metrics[metric_name].mean_std.std) + # layers.append(name) + + # # means_list, stds_list, model_references = list(means.values()), list(stds.values()), list(means.keys()) + # # for i, model_ref in enumerate(model_references): + means = expand_to_fit(all_layer_names=list(self.layers.keys()), values=means, subset_layer_names=layers_with_data) + stds = expand_to_fit(all_layer_names=list(self.layers.keys()), values=stds, subset_layer_names=layers_with_data) + + return means, stds def available_metrics(self) -> Dict[str, Dict[str, Any]]: all_metrics = set() @@ -157,25 +199,42 @@ def available_metrics(self) -> Dict[str, Dict[str, Any]]: for metric in all_metrics: info = { 'layers': [], - 'has_mean_std': False, - 'has_histogram': False, - 'has_heatmap': False, - 'has_model_ref': False + PlotType.MEAN_STD.value: False, + PlotType.HISTOGRAM.value: False, + PlotType.HEATMAP.value: False, + PlotType.SCATTER_PLOT.value: False } for layer_name, layer in self.layers.items(): if metric in layer.metrics: info['layers'].append(layer_name) - for m in layer.metrics[metric]: - if m.mean_std: - info['has_mean_std'] = True - if m.histogram: - info['has_histogram'] = True - if m.heatmap: - info['has_heatmap'] = True - if m.model_ref: - info['has_model_ref'] = True + m = layer.metrics[metric] + if m.mean_std: + info[PlotType.MEAN_STD.value] = True + if m.histogram: + info[PlotType.HISTOGRAM.value] = True + if m.heatmap: + info[PlotType.HEATMAP.value] = True + if m.scatter_plot: + info[PlotType.SCATTER_PLOT.value] = True metric_info[metric] = info return metric_info + + def available_plot_types(self, plot_type: str) -> Dict[str, 
List[str]]: + # Returns dictionary with key metric_name and value: list of layers for which that metric has data + metric_info = self.available_metrics() + out = {} + plot_type = 'mean_std' if plot_type == 'line_plot' else plot_type + assert plot_type in [p.value for p in PlotType], f"Plot type {plot_type} is not valid. Must be one of {[p.value for p in PlotType]}" + for metric_name, info in metric_info.items(): + if info[plot_type]: + out[metric_name] = info['layers'] + return out + + def available_metrics_at_layer(self, layer_name: str) -> List[str]: + if layer_name in self.layers: + return list(self.layers[layer_name].metrics.keys()) + else: + return [] def print_metric_summary(self): metric_info = self.available_metrics() @@ -183,10 +242,14 @@ def print_metric_summary(self): for metric, info in metric_info.items(): print(f"\nMetric: {metric}") # print(f" Available in layers: {', '.join(info['layers'])}") - print(f" Has mean/std: {'Yes' if info['has_mean_std'] else 'No'}") - print(f" Has histogram: {'Yes' if info['has_histogram'] else 'No'}") - print(f" Has heatmap: {'Yes' if info['has_heatmap'] else 'No'}") - print(f" Has model reference: {'Yes' if info['has_model_ref'] else 'No'}") + print(f" Has mean/std: {'Yes' if info[PlotType.MEAN_STD.value] else 'No'}") + print(f" Has histogram: {'Yes' if info[PlotType.HISTOGRAM.value] else 'No'}") + print(f" Has heatmap: {'Yes' if info[PlotType.HEATMAP.value] else 'No'}") + print(f" Has scatter plot: {'Yes' if info[PlotType.SCATTER_PLOT.value] else 'No'}") + + def finalise(self): + self.layer_names = list(self.layers.keys()) + self.metric_names = list(set([metric for layer in self.layers.values() for metric in layer.metrics.keys()])) def save(self, path: str): path = Path(path) diff --git a/mergekit/metric_methods/metrics.py b/mergekit/metric_methods/metrics.py index ad3b2060..8dc3618b 100644 --- a/mergekit/metric_methods/metrics.py +++ b/mergekit/metric_methods/metrics.py @@ -135,23 +135,19 @@ def mse( # Tensor Analysis (number of tensors can vary) -def weight_magnitude(tensors: List[torch.Tensor], model_refs: List[ModelReference]) -> List[Metric]: - output = [] - for tensor, model_reference in zip(tensors, model_refs): - weight_magnitudes = torch.abs(tensor.flatten()) - hist_info = compute_histogram(weight_magnitudes, 100) - output.append(Metric( - histogram=Histogram(count=hist_info[0], - edges=hist_info[1], - widths=hist_info[2] - ), - mean_std=MeanStd(mean=weight_magnitudes.mean().item(), - std=weight_magnitudes.std().item()), - model_ref=model_reference - )) - return output - -def numerical_rank(tensors: List[torch.Tensor], model_refs: List[ModelReference], epsilon: float = 1e-5) -> List[Metric]: +def weight_magnitude(tensor: torch.Tensor) -> Metric: + weight_magnitudes = torch.abs(tensor.flatten()) + hist_info = compute_histogram(weight_magnitudes, 100) + return Metric( + histogram=Histogram(count=hist_info[0], + edges=hist_info[1], + widths=hist_info[2] + ), + mean_std=MeanStd(mean=weight_magnitudes.mean().item(), + std=weight_magnitudes.std().item()), + ) + +def numerical_rank(tensor: torch.Tensor, epsilon: float = 1e-5) -> Metric: """ Computes the numerical rank of the representations matrix X based on the singular values of its sample covariance matrix. 
The rank is determined as the number of singular values @@ -172,31 +168,26 @@ def numerical_rank(tensors: List[torch.Tensor], model_refs: List[ModelReference] https://arxiv.org/pdf/2305.19753.pdf """ - output = [] - for tensor, model_reference in zip(tensors, model_refs): - # Center the data by subtracting the mean - X_centered = tensor - torch.mean(tensor, dim=0) - X_std = torch.std(X_centered, dim=0, unbiased=False) - X_centered /= X_std - - # Compute the sample covariance matrix - covariance_matrix = X_centered.t() @ X_centered / (tensor.shape[0] - 1) - # Compute singular values using SVD on the covariance matrix - U, singular_values, V = torch.svd(covariance_matrix.cpu()) - # Determine the threshold - threshold = singular_values[0] * epsilon - # Count singular values greater than the threshold - num_rank = torch.sum(singular_values > threshold).item() - - value = int(num_rank) - - output.append( - Metric( - model_ref=model_reference, - mean_std=MeanStd( - mean=value, - std=None), - )) - - return output + # Center the data by subtracting the mean + X_centered = tensor - torch.mean(tensor, dim=0) + X_std = torch.std(X_centered, dim=0, unbiased=False) + X_centered /= X_std + + # Compute the sample covariance matrix + covariance_matrix = X_centered.t() @ X_centered / (tensor.shape[0] - 1) + # Compute singular values using SVD on the covariance matrix + U, singular_values, V = torch.svd(covariance_matrix.cpu()) + # Determine the threshold + threshold = singular_values[0] * epsilon + # Count singular values greater than the threshold + num_rank = torch.sum(singular_values > threshold).item() + + value = int(num_rank) + + return Metric( + mean_std=MeanStd( + mean=value, + std=None), + ) + diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index e1733df1..3ebc4409 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -9,48 +9,46 @@ from dash import dcc, html from dash.dependencies import Input, Output, State from mergekit.metric_methods.all_metrics import Layer -from mergekit.metric_methods.base import Results +from mergekit.metric_methods.base import Results, PlotType +from mergekit.common import ModelReference global_colours_list = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] +global_shapes_list = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] class ResultsHandler: + # Can accept many intra-model results, but only one inter-model result!! (X) """ - Object to handle metrics results. Allows for easy plotting of metrics by layer and across layers. - - Input: - Use the load_metrics method to load the metrics into the handler. - metrics: List of tasks and their metrics. This is the output of the run_measure function in mergekit.measure. - - Attributes: - all_stats: Dictionary of recorded statistics for each layer. e.g. {'layer_name': {'cosine_similarity_mean': 0.5, 'cosine_similarity_std': 0.1}} - metric_names: List of names of all statistics available. e.g. ['cosine_similarity_mean', 'cosine_similarity_std'] - layer_names: List of layer names. - - Methods: - load_metrics: Load the metrics into the handler. - # stats_at_layer: Get the metrics for a specific layer. - # info_at_layer: Get the weight info for a specific layer. - line_plot: Plot a line plot of the chosen stat across layers. - plotly_layer_histogram: Plot a histogram of the stat for a specific layer. 
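# For orientation, typical usage of ResultsHandler (an illustrative sketch; it mirrors
# the wiring in run_metrics.py later in this patch):
#
#     handler = ResultsHandler()
#     handler.load_results(inter_results)        # a Results built from two model_refs
#     for res in intra_results.values():
#         handler.load_results(res)              # a Results built from a single model_ref
#     app = create_app(results_handler=handler)
#     app.run_server()
#
# load_results() routes each Results object to inter_model_results or intra_model_results
# depending on len(results.model_refs).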
+ """ def __init__(self): - self.results = Results() - - def load_metrics(self, metrics: List[Tuple[Task, Layer]]): - self.metric_names = [] - for task, metric in metrics: - if metric is not None: - self.results.add_layer(metric, name=task.weight_info.name) - self.metric_names.extend(list(metric.metrics.keys())) - self.layer_names = list(self.results.layers.keys()) - self.metric_names = list(set(self.metric_names)) + self.intra_model_results: Dict[ModelReference, Results] = {} + self.inter_model_results: Results = None + self.available_layer_plots = { + 'mean_std': [], + 'histogram': [], + 'heatmap': [], + 'scatter_plot': [] + } def load_results(self, results: Results): - self.results = results - self.layer_names = list(self.results.layers.keys()) - self.metric_names = list(set([metric for layer in self.results.layers.values() for metric in layer.metrics.keys()])) + results.finalise() + if len(results.model_refs) == 2: + self.inter_model_results = results + elif len(results.model_refs) == 1: + self.intra_model_results[results.model_refs[0]] = results + else: + raise ValueError("Results should have either 1 or 2 model_refs") + + for plot_type in self.available_layer_plots.keys(): + self.available_layer_plots[plot_type] = list(self.inter_model_results.available_plot_types(plot_type).keys()) + for model_ref, results in self.intra_model_results.items(): + self.available_layer_plots[plot_type] += list(results.available_plot_types(plot_type).keys()) + + self.available_layer_plots[plot_type] = list(set(self.available_layer_plots[plot_type])) - def categorise_layers(self, layer_names): + self.all_results = list(self.intra_model_results.values()) + [self.inter_model_results] + + def categorise_layers(self, layer_names) -> List[str]: # Hardcoded layernames for now - can be extended to include more categories or further generalised based on config categories = [] for name in layer_names: @@ -65,32 +63,31 @@ def categorise_layers(self, layer_names): return categories def plotly_line_plots(self, metric_name:str): - if metric_name not in self.metric_names: - print(f"Stat {metric_name} not found") - return [], [] + # if metric_name not in self.metric_names: + # print(f"Stat {metric_name} not found") + # return [], [] - layer_names = self.layer_names - means, stds, model_refs = self.results.get_lineplot_data(metric_name) traces = [] - available_shapes = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] + if metric_name in self.inter_model_results.available_plot_types('line_plot'): # bring if case into loop? 
(X) + layer_names = self.inter_model_results.layer_names + means, stds = self.inter_model_results.get_lineplot_data(metric_name) + categorised_layers = self.categorise_layers(layer_names) # Different category for each layer type + unique_categories = list(set(categorised_layers)) + traces = self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories) - if len(model_refs) > 1: - unique_categories = [str(ref) for ref in model_refs] - layer_categories = [[str(model_refs[i])]*len(layer_names) for i in range(len(model_refs))] else: - layer_categories = [self.categorise_layers(layer_names)] - unique_categories = list(set(layer_categories[0])) - for i, model_ref in enumerate(model_refs): - traces.extend(self._plotly_line_plot(layer_names, means[i], stds[i], layer_categories[i], unique_categories, shape=available_shapes[i%len(available_shapes)])) - + unique_categories = list(self.intra_model_results.keys()) + for i, (model_ref, results) in enumerate(self.intra_model_results.items()): + layer_names = results.layer_names + means, stds = results.get_lineplot_data(metric_name) + categorised_layers = [model_ref]*len(layer_names) # Different category for each model, every layer in each model has the same category + shape = global_shapes_list[i%len(global_shapes_list)] + traces.extend(self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories, shape)) + return traces, layer_names - def _plotly_line_plot(self, x_values, means, stds, layer_categories, unique_categories, shape:str='circle', **kwargs): + def _plotly_line_plot(self, x_values, means, stds, categorised_layers, unique_categories, shape:str='circle', **kwargs): """ - Plot the stat values across layers using Plotly. - - Args: - stat (str): The name of the stat to plot. Returns: List[go.Scatter]: List of Plotly Scatter objects. @@ -105,8 +102,8 @@ def _plotly_line_plot(self, x_values, means, stds, layer_categories, unique_cate traces = [] for category in unique_categories: - y_category = [means[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] - std_category = [stds[i] if layer_categories[i] == category else None for i in range(len(self.layer_names))] + y_category = [means[i] if categorised_layers[i] == category else None for i in range(len(categorised_layers))] + std_category = [stds[i] if categorised_layers[i] == category else None for i in range(len(categorised_layers))] if all([y is None for y in y_category]): continue @@ -119,32 +116,95 @@ def _plotly_line_plot(self, x_values, means, stds, layer_categories, unique_cate visible=True ), mode='markers', - name=category, + name=str(category), marker=dict(color=category_styles[category]), marker_symbol=shape )) return traces - def plotly_layer_heatmap(self, layer_name:str, metric_name:str): - """ - Plot the stat values as a heatmap. - - Args: - layer_name (str): The name of the layer. - metric_name (str): The name of the stat to plot. - Returns: - go.Heatmap: Plotly Heatmap object. - """ - metrics_list = self.results.layers[layer_name].metrics[metric_name] - if len(metrics_list) > 1: - raise Warning(f"Multiple heatmaps found for {metric_name} at layer {layer_name}. 
Using the first one.") + def plotly_layer_plot(self, layer_name:str, metric_name:str, plot_type:str): + assert plot_type in [p.value for p in PlotType], f"Plot type {plot_type} not in {[p.value for p in PlotType]}" + data = [] + + for result in self.all_results: + valid_metrics = result.available_plot_types(plot_type) + if metric_name in valid_metrics.keys(): + if layer_name in valid_metrics[metric_name]: + data.append(getattr(result.layers[layer_name].metrics[metric_name], plot_type)) - heatmap = self.results.layers[layer_name].metrics[metric_name][0].heatmap.data + return self.get_traces(data, plot_type) # Can prob use type of data to determine plot type (X) + + def get_traces(self, data:List, plot_type): + if plot_type == PlotType.HEATMAP.value: + traces = [go.Heatmap( + z=d.data, + colorscale='RdBu' + ) for d in data] + elif plot_type == PlotType.SCATTER_PLOT.value: + traces = [go.Scatter( + x = d.x, + y = d.y + ) for d in data] + elif plot_type == PlotType.HISTOGRAM.value: + traces = [] + for i, d in enumerate(data): + count, edges, widths = d.count, d.edges, d.widths + traces.append(go.Bar( + x=edges[:-1], + y=count, + width=widths, + marker=dict( + color=global_colours_list[i], + opacity=0.75, + line=dict( + color='black', + width=1 + ) + ))) + else: + raise ValueError(f'{plot_type} not valid for layer specific plot') - return [go.Heatmap( - z=heatmap, - colorscale='RdBu' - )] + return traces + + + # def plotly_layer_heatmap(self, layer_name:str, metric_name:str): + # """ + # Plot the stat values as a heatmap. + + # Args: + # layer_name (str): The name of the layer. + # metric_name (str): The name of the stat to plot. + # Returns: + # go.Heatmap: Plotly Heatmap object. + # """ + # heatmaps = [] + # if metric_name in self.inter_model_results.available_plot_types('histogram').keys() \ + # and layer_name in self.inter_model_results.available_plot_types('histogram')[metric_name]: + # heatmaps.append(self.inter_model_results.layers[layer_name].metrics[metric_name].heatmap.data) + # else: + # for model_ref, results in self.intra_model_results.items(): + # if metric_name in results.available_plot_types('heatmap'): + # heatmaps.append(results.layers[layer_name].metrics[metric_name].heatmap.data) + + # return [go.Heatmap( + # z=data, + # colorscale='RdBu' + # ) for data in heatmaps] + + # def plotly_scatter_plot(self, layer_name:str, metric_name:str): + # scatters = [] + # if metric_name in self.inter_model_results.available_plot_types('scatter_plot').keys() \ + # and layer_name in self.inter_model_results.available_plot_types('scatter_plot')[metric_name]: + # scatters.append(self.inter_model_results.layers[layer_name].metrics[metric_name].scatter_plot) + # else: + # for model_ref, results in self.intra_model_results.items(): + # if metric_name in results.available_plot_types('scatter_plot'): + # scatters.append(results.layers[layer_name].metrics[metric_name].scatter_plot) + + # return [go.Scatter( + # x = scatter.x, + # y = scatter.y + # ) for scatter in scatters] def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): """ @@ -166,39 +226,53 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): if kwarg in kwargs: getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) - def plotly_layer_histogram(self, layer_name: str, metric_name: str): - metric_list = self.results.layers[layer_name].metrics[metric_name] - - traces = [] - for i, metric in enumerate(metric_list): - hist = metric.histogram - count, edges, widths = hist.count, hist.edges, hist.widths - 
traces.append(go.Bar( - x=edges[:-1], - y=count, - width=widths, - marker=dict( - color=global_colours_list[i], - opacity=0.75, - line=dict( - color='black', - width=1 - ) - ), - name=str(metric.model_ref) - )) - return traces + # def plotly_layer_histogram(self, layer_name: str, metric_name: str): + # histograms = [] + # if metric_name in self.inter_model_results.available_plot_types('histogram'): + # histograms.append(self.inter_model_results.layers[layer_name].metrics[metric_name]) + # else: + # for model_ref, results in self.intra_model_results.items(): + # if metric_name in results.available_plot_types('histogram'): + + # metric_list = self.results.layers[layer_name].metrics[metric_name] + + # traces = [] + # for i, metric in enumerate(metric_list): + # hist = metric.histogram + # count, edges, widths = hist.count, hist.edges, hist.widths + # traces.append(go.Bar( + # x=edges[:-1], + # y=count, + # width=widths, + # marker=dict( + # color=global_colours_list[i], + # opacity=0.75, + # line=dict( + # color='black', + # width=1 + # ) + # ), + # # name=str(metric.model_ref) + # )) + # return traces def layer_plot_options(self, layer_name: str): - layer = self.results.layers[layer_name] - - return [ - {"label": f"{metric.title()} Histogram", "value": [metric, 'histogram']} - for metric in layer.metrics_with_attribute('histogram') - ] + [ - {"label": f"{metric.title()} Heatmap", "value": [metric, 'heatmap']} - for metric in layer.metrics_with_attribute('heatmap') - ] + metric_options = [] + for plot_type in PlotType: + if plot_type == PlotType.MEAN_STD: + continue + metric_options.extend([ + {"label": f"{metric.title()} {plot_type.value}", "value": [metric, plot_type.value]} + for metric in self.inter_model_results.layers[layer_name].metrics_with_attribute(plot_type.value)] + ) + for result in self.all_results: + metric_options.extend([ + {"label": f"{metric.title()} {plot_type.value}", "value": [metric, plot_type.value]} + for metric in result.layers[layer_name].metrics_with_attribute(plot_type.value)] + ) + break # Assuming all intra-model results have the same metrics + return metric_options + def create_app(results_handler): app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css']) @@ -206,8 +280,8 @@ def create_app(results_handler): app.layout = html.Div([ create_header(), create_line_plot_section(results_handler), - create_layer_metrics_section(), - create_heatmap_section(results_handler) + create_single_layer_section(), + create_across_layers_section(results_handler) ]) register_callbacks(app, results_handler) @@ -223,14 +297,14 @@ def create_line_plot_section(results_handler): dcc.Dropdown( id='line-plot-dropdown', options=[{'label': metric_name.replace('_', ' ').title(), 'value': metric_name} - for metric_name in results_handler.metric_names], + for metric_name in results_handler.available_layer_plots['mean_std']], value='cosine_similarity', style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'} ), dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'}) ], className='container-fluid') -def create_layer_metrics_section(): +def create_single_layer_section(): return html.Div([ html.H3('Layer Metrics', style={'textAlign': 'center'}), dcc.Dropdown( @@ -242,23 +316,22 @@ def create_layer_metrics_section(): dcc.Graph(id='layer-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'}) ], className='container-fluid') -def 
create_heatmap_section(results_handler): - if hasattr(results_handler.results, 'others') and isinstance(results_handler.results.others, dict): - heatmap_sections = [] - for i, (key, value) in enumerate(results_handler.results.others.items()): - model_name = key.split('Representations')[1] - model_name = model_name.split('.')[0].replace('_', ' ').replace('-',' ').title() - metric = key.split('||')[-1].replace('_', ' ').replace('-',' ').title() +def create_across_layers_section(results_handler): + results = list(results_handler.intra_model_results.values()) + [results_handler.inter_model_results] + + plot_sections = [] - title = f'{model_name} - {metric}' + for result in results: + if hasattr(result, 'across_layer_metrics'): + for metric_name, metric in result.across_layer_metrics.items(): + for attr in ['histogram', 'heatmap', 'scatter_plot']: + if hasattr(metric, attr): + plot_sections.append(html.Div([ + html.H3(f'{attr+metric_name.replace("_", " ").title()} {attr.replace("_", " ").title()}', style={'textAlign': 'center'}), + dcc.Graph(id=f'{attr}-plot-{metric_name}', style={'width': '50%', 'height': '50%', 'position': 'relative'}) + ], className='container-fluid')) - heatmap_sections.append(html.Div([ - html.H3(f'Heatmap: {title}', style={'textAlign': 'center'}), - dcc.Graph(id=f'heatmap-plot-{i}', style={'width': '50%', 'height': '50%', 'position': 'relative'}) - ], className='container-fluid')) - return html.Div(heatmap_sections) - else: - return html.Div() + return html.Div(plot_sections) def default_option(options, current_value): if not options: @@ -277,7 +350,7 @@ def register_callbacks(app, results_handler): Input('line-plot', 'clickData'), Input('line-plot-dropdown', 'value') ) - def update_metric_dropdown_options(clickData, selected_metric): + def update_metric_dropdown_options(clickData, selected_metric): # What distinguishes these options from layer-specific options? 
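# (No clickData means no point on the line plot has been clicked yet, so no layer is
#  selected and the metric dropdown is cleared by returning empty options below.)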
if not clickData: return [], None @@ -306,6 +379,13 @@ def display_layer_data(selected_metric, clickData): if not selected_metric: selected_metric = results_handler.layer_plot_options(layer_name)[0]['value'] + + # metric_options = [] + # for result in results_handler.all_results: + # metric_options.extend(result.available_metrics_at_layer(layer_name)) + # metric_options = list(set(metric_options)) + # selected_metric = metric_options[0] + metric_name, plot_type = selected_metric # Define default axis titles @@ -318,13 +398,15 @@ def display_layer_data(selected_metric, clickData): yaxis_title = "Model 0 Head" plot_function = { - 'histogram': results_handler.plotly_layer_histogram, - 'heatmap': results_handler.plotly_layer_heatmap - }.get(plot_type.lower(), - lambda *args, **kwargs: go.Figure()) # Defaults to *function* to produce empty figure + 'histogram': lambda layer_name, metric_name: results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.HISTOGRAM.value), + 'heatmap': lambda layer_name, metric_name: results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.HEATMAP.value), + 'scatter_plot': lambda layer_name, metric_name: results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.SCATTER_PLOT.value) + } - traces = plot_function(layer_name=layer_name, - metric_name=metric_name) + traces = plot_function[plot_type.lower()]( + layer_name=layer_name, + metric_name=metric_name + ) return create_figure(traces=traces, title=f"{plot_type.title()} for {layer_name} | {metric_name}", @@ -347,7 +429,8 @@ def update_line_plot(selected_metric): traces, layer_names = results_handler.plotly_line_plots(metric_name=selected_metric) fig = go.Figure() for trace in traces: - fig.add_trace(trace) + if trace: + fig.add_trace(trace) fig.update_layout( title=f"{selected_metric.replace('_', ' ').title()} Across Layers", @@ -359,35 +442,63 @@ def update_line_plot(selected_metric): yaxis=dict(title=selected_metric.replace('_', ' ').title()) ) return fig - # Dynamically create callbacks for each heatmap plot - if hasattr(results_handler.results, 'others') and isinstance(results_handler.results.others, dict): - for i, (key, value) in enumerate(results_handler.results.others.items()): - if isinstance(value.data, (list, np.ndarray)): # Assuming heatmap data is in array-like format - @app.callback( - Output(f'heatmap-plot-{i}', 'figure'), - Input(f'heatmap-plot-{i}', 'id') # Dummy input to trigger the callback on load - ) - def update_heatmap_plot(_key=key): - key = list(results_handler.results.others.keys())[int(_key.split('-')[-1])] - heatmap_data = results_handler.results.others[key].data - fig = go.Figure(data=go.Heatmap( - z=heatmap_data, - colorscale='Viridis', # Using Viridis colormap - zmin=np.nanmin(heatmap_data), # Set the scale min to the min data value - zmax=np.nanmax(heatmap_data), # Set the scale max to the max data value - colorbar=dict(title='Scale') # Customize the color bar - )) - default_layout_options = { - 'xaxis_title':"X Axis", - 'yaxis_title':"Y Axis" - } - if results_handler.results.others[key].update_layout_options: - default_layout_options.update(results_handler.results.others[key].update_layout_options) - fig.update_layout( - title=f"Heatmap: {_key}", - **default_layout_options - ) - return fig + # # Dynamically create callbacks for each heatmap plot + # if hasattr(results_handler.results, 'others') and isinstance(results_handler.results.others, dict): + # for i, (key, value) in enumerate(results_handler.results.others.items()): + # if isinstance(value.data, 
(list, np.ndarray)): # Assuming heatmap data is in array-like format + # @app.callback( + # Output(f'heatmap-plot-{i}', 'figure'), + # Input(f'heatmap-plot-{i}', 'id') # Dummy input to trigger the callback on load + # ) + # def update_heatmap_plot(_key=key): + # key = list(results_handler.results.others.keys())[int(_key.split('-')[-1])] + # heatmap_data = results_handler.results.others[key].data + # fig = go.Figure(data=go.Heatmap( + # z=heatmap_data, + # colorscale='Viridis', # Using Viridis colormap + # zmin=np.nanmin(heatmap_data), # Set the scale min to the min data value + # zmax=np.nanmax(heatmap_data), # Set the scale max to the max data value + # colorbar=dict(title='Scale') # Customize the color bar + # )) + # default_layout_options = { + # 'xaxis_title':"X Axis", + # 'yaxis_title':"Y Axis" + # } + # if results_handler.results.others[key].update_layout_options: + # default_layout_options.update(results_handler.results.others[key].update_layout_options) + # fig.update_layout( + # title=f"Heatmap: {_key}", + # **default_layout_options + # ) + # return fig + + for result in results_handler.all_results: + if hasattr(result, 'across_layer_metrics'): + for metric_name, metric in result.across_layer_metrics.items(): + for attr in ['histogram', 'heatmap', 'scatter_plot']: + if hasattr(metric, attr): + id=f'{attr}-plot-{metric_name}' + + @app.callback( + Output(id, 'figure'), + Input(id, 'id') + ) + def update_across_layers_plot(_id=id): + metric_name = _id.split('-')[-1] + plot_function = { + 'histogram': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.HISTOGRAM), + 'heatmap': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.HEATMAP), + 'scatter_plot': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.SCATTER_PLOT) + }.get(attr, + lambda *args, **kwargs: go.Figure()) + traces = plot_function(metric_name=metric_name) + + return create_figure(traces=traces, + title=f"{id} | {metric_name}", + # xaxis_title=xaxis_title, + # yaxis_title=yaxis_title + ) + def create_figure(traces, title, xaxis_title, yaxis_title): fig = go.Figure() diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py index 624af005..6139d6f3 100644 --- a/mergekit/scripts/run_metrics.py +++ b/mergekit/scripts/run_metrics.py @@ -6,6 +6,13 @@ from mergekit.merge import MergeOptions from mergekit.merge import run_merge from mergekit.plot_tools.plot_tools import create_app, ResultsHandler +from mergekit.metric_methods.base import Results + +def create_temp_config(config_yml, **kwargs): + with open(config_yml, "r", encoding="utf-8") as config: + config = yaml.safe_load(config) + config.update(kwargs) + return MergeConfiguration.model_validate(config) @click.command() @click.option('--output_path', default="./merged", help='folder to store the result in.') @@ -14,22 +21,58 @@ @click.option('--lazy_unpickle', default=False, help='experimental low-memory model loader.') @click.option('--low_cpu_memory', default=False, help='enable if you somehow have more VRAM than RAM+swap') def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory): - with open(config_yml, "r", encoding="utf-8") as fp: - metric_config = MergeConfiguration.model_validate(yaml.safe_load(fp)) - - metrics_results = run_merge( - metric_config, - out_path=output_path, - options=MergeOptions( - cuda=torch.cuda.is_available(), - copy_tokenizer=copy_tokenizer, - lazy_unpickle=lazy_unpickle, - low_cpu_memory=low_cpu_memory, - ), - ) + with 
open(config_yml, "r", encoding="utf-8") as config: + config = yaml.safe_load(config) + metric_config = MergeConfiguration.model_validate(config) + + models = metric_config.models + intra_model = config['parameters']['intra_model_metrics'] + inter_model = config['parameters']['inter_model_metrics'] + + intra_results = {} + inter_results = None + if intra_model: + print(f"Running intra-model metrics for {len(models)} models: {models}") + for model in models: + temp_config = create_temp_config(config_yml, models=[{'model':model.model.model.path}], parameters=({'intra_model_metrics':True, 'inter_model_metrics':False})) + print(f" {model}") + metrics_out = run_merge( + temp_config, + out_path=output_path, + options=MergeOptions( + cuda=torch.cuda.is_available(), + copy_tokenizer=copy_tokenizer, + lazy_unpickle=lazy_unpickle, + low_cpu_memory=low_cpu_memory, + ), + ) + intra_results[model.model] = Results().load_metrics(metrics_out, model_refs=[model.model]) + + if inter_model: + assert len(models) == 2, "Inter-model metrics require exactly 2 models" + print(f"Running inter-model metrics for {models}") + temp_config = create_temp_config(config_yml, parameters=({'intra_model_metrics':False, 'inter_model_metrics':True})) + + print(f" {models}") + metrics_out = run_merge( + temp_config, + out_path=output_path, + options=MergeOptions( + cuda=torch.cuda.is_available(), + copy_tokenizer=copy_tokenizer, + lazy_unpickle=lazy_unpickle, + low_cpu_memory=low_cpu_memory, + ), + ) + inter_results = Results().load_metrics(metrics_out, model_refs=models) + handler = ResultsHandler() - handler.load_metrics(metrics_results) + + handler.load_results(inter_results) + for result in intra_results.values(): + handler.load_results(result) + app = create_app(results_handler=handler) app.run_server() diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 4cb264f6..8750449f 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -8,7 +8,7 @@ from tqdm import tqdm import torch.nn.functional as F -from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer +from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer, ScatterPlot from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap @@ -139,6 +139,84 @@ def aggregate(self) -> Metric: def clear(self) -> None: pass +import math +class _CKA(object): + # Class from https://github.com/jayroxis/CKA-similarity/blob/main/CKA.py + def __init__(self): + pass + + def centering(self, K): + n = K.shape[0] + unit = np.ones([n, n]) + I = np.eye(n) + H = I - unit / n + return np.dot(np.dot(H, K), H) + + def rbf(self, X, sigma=None): + GX = np.dot(X, X.T) + KX = np.diag(GX) - GX + (np.diag(GX) - GX).T + if sigma is None: + mdist = np.median(KX[KX != 0]) + sigma = math.sqrt(mdist) + KX *= -0.5 / (sigma * sigma) + KX = np.exp(KX) + return KX + + def kernel_HSIC(self, X, Y, sigma): + return np.sum( + self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)) + ) + + def linear_HSIC(self, X, Y): + L_X = X @ X.T + L_Y = Y @ Y.T + return np.sum(self.centering(L_X) * self.centering(L_Y)) + + def linear_CKA(self, X, Y): + hsic = self.linear_HSIC(X, Y) + var1 = np.sqrt(self.linear_HSIC(X, X)) + var2 = np.sqrt(self.linear_HSIC(Y, Y)) + + return hsic / (var1 * var2) + + def kernel_CKA(self, X, Y, sigma=None): + hsic = 
self.kernel_HSIC(X, Y, sigma) + var1 = np.sqrt(self.kernel_HSIC(X, X, sigma)) + var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma)) + + return hsic / (var1 * var2) + +class CKA(MetricAggregator): + def __init__(self, device: str = "cpu"): + self.device = device + self.cka = _CKA() + + def process_dataset(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + self.result = self.cka.linear_CKA(batch_a.cpu().numpy(), batch_b.cpu().numpy()) + + def aggregate(self) -> Metric: + return Metric(mean_std=MeanStd(mean=self.result)) + + +from sklearn.manifold import TSNE + +class t_SNE(MetricAggregator): + def __init__(self, device: str = "cpu"): + self.device = device + self.tsne = TSNE(n_components=2, random_state=42) + + def process_dataset(self, data: torch.Tensor) -> None: + self.result = self.tsne.fit_transform(data.cpu().numpy()) + + def aggregate(self) -> Metric: + return Metric( + scatter_plot=ScatterPlot( + x=self.result[:, 0], + y=self.result[:, 1], + ) + ) + + class LayerByIndex: def __init__(self, reps_path: str): self.reps_path = reps_path @@ -251,7 +329,9 @@ def results_list_to_heatmap(all_results, metric_names:List[str]) -> dict: METRICS_TABLE = { 'cosine_similarity': Cosine_Similarity, 'mse': MSE, - 'linearity_score': Linearity_Score + 'linearity_score': Linearity_Score, + 'cka': CKA, + 't-sne': t_SNE } @click.command() @@ -292,9 +372,36 @@ def main(config_yml: str): heatmaps = results_list_to_heatmap(results_list, metric_names=[metric.__name__.lower() for metric in metric_classes]) for metric_name, heatmap in heatmaps.items(): - all_results.others[reps_path + '||' + metric_name] = heatmap + all_results.across_layer_metircs[reps_path + '||' + metric_name] = heatmap # Address this - new implementation only ever has one model per results object + + if config['analyse_individually']: + results = Results() + for reps_path in tqdm(model_paths, desc='Model', leave=False, total=len(model_paths), initial = 1): + with LayerByIndex(reps_path) as reps: + for i, layer in enumerate(tqdm(reps, desc='Layer', leave=False, initial = 1)): + layer_name = f'Layer_{i}' + layer_results = Layer(WeightInfo(name=layer_name)) + for metric_name, metric_class in use_metrics.items(): + metric = metric_class(device=device) + # Want automatic choice of metrics according to whether it requires batches or single data. + # For now only accept metrics that require single data. 
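The branch below relies on a simple convention: per-batch metrics implement process_batch and are updated incrementally, while one-shot metrics such as CKA and t_SNE above implement process_dataset and operate on whole concatenated samples at once. A hypothetical further one-shot aggregator, sketched only to illustrate that convention (it assumes this module's existing imports: torch, MetricAggregator, Metric, MeanStd):

class MeanPairwiseDistance(MetricAggregator):
    def __init__(self, device: str = "cpu"):
        self.device = device
        self.result = None

    def process_dataset(self, data: torch.Tensor) -> None:
        # mean Euclidean distance over all pairs of rows in the concatenated sample
        self.result = torch.pdist(data.to(self.device)).mean().item()

    def aggregate(self) -> Metric:
        return Metric(mean_std=MeanStd(mean=self.result))

Wiring such an aggregator in would only require an entry in METRICS_TABLE and the corresponding flag in the config's metrics block.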
+ if not hasattr(metric, 'process_dataset'): + print(f'{metric_name} does not support dataset processing') + continue + collect_batches = [] + for batch in tqdm(layer, desc='Batch', leave=False, initial = 1): + batch = torch.tensor(layer[batch][:]).to(device) + collect_batches.append(batch) + if len(collect_batches) == 32: + single_data = torch.cat(collect_batches) # check dimensionality + metric.process_dataset(single_data) + continue + layer_results.add_metric(metric.aggregate(), metric_name) + results.add_layer(layer_results, name=layer_name) + results.save(f'results.pkl') + # all_results.save('results.pkl') + - all_results.save('results.pkl') if __name__ == '__main__': main() \ No newline at end of file diff --git a/representations/representations.py b/representations/representations.py new file mode 100644 index 00000000..00051352 --- /dev/null +++ b/representations/representations.py @@ -0,0 +1,156 @@ +#%% +import torch +import h5py +import numpy as np +import click +import yaml +from pathlib import Path +from typing import List, Dict, Any, Optional +from tqdm import tqdm +import torch.nn.functional as F + +from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer +from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap + + +from mergekit.architecture import WeightInfo +from mergekit.common import ModelReference +from mergekit.merge_methods.base import MergeMethod + +import enum + +class LayerByIndex: + def __init__(self, reps_path: str): + self.reps_path = reps_path + self.representations = None + self.layers = None + + def __enter__(self): + self.representations = h5py.File(self.reps_path, 'r') + self.layers = list(self.representations.keys()) + return self + + def __exit__(self, *args, **kwargs): + if self.representations: + self.representations.close() + + def __getitem__(self, idx: int): + return self.representations[self.layers[idx]] + + def __len__(self) -> int: + return len(self.layers) + + def __iter__(self): + return iter(self.representations[layer] for layer in self.layers) + + +class ModelAnalysisType(enum.Enum): + INDIVIDUAL = "individual" + COMPARISON = "comparison" + + +class LayerComparisonType(enum.Enum): + SINGLE = "single" # Layer i + BLOCK = "block" # Layer i in model 1 and layer i+(block size) in model 1 + CORRESPONDING_LAYERS = "corresponding_layers" # Layer i in model 1 and layer i in model 2 + ALL_LAYERS = "all_layers" # Layer i in model 1 and Layer j in model (1 or 2) + + +class MetricInput(enum.Enum): + ONE_SHOT = "one_shot" + BATCHES = "batches" + + +def valid_experiment(analysis_type, comparison_type, metric_input): + if comparison_type == LayerComparisonType.ALL_LAYERS: + raise ValueError("Comparison type 'all_layers' is not supported") + if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: + raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") + if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + raise ValueError("Comparison type 'corresponding_layers' only supported for comparison analysis") + + +def layer_loader(representation_path): + with LayerByIndex(representation_path) as representations: + for layer in tqdm(representations, desc='Analysing Layer', + total=len(representations), leave=False, initial = 1): + yield layer + +def 
batch_loader(layer, device): + for batch in tqdm(layer, desc='processing batch', + total=len(layer), leave=False, initial = 1): + yield torch.tensor(layer[batch][:], device=device) + +# Experiment Loops +def single(representation_path: str): + for layer_idx, layer in enumerate(layer_loader(representation_path)): + for batch in batch_loader(layer, device="cpu"): + yield batch, layer_idx + +def block(representation_path: str, block_size: int, device: str = "cpu"): + with LayerByIndex(representation_path) as reps: + for layer_idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {block_size}-block, Block Start at Layer', + total=len(reps) - block_size, leave=False, initial = 1): + if layer_idx + block_size >= len(reps): + break + + block_end = reps[layer_idx + block_size] + + for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', + total=len(block_start), leave=False, initial = 1): + batch_0 = torch.tensor(block_start[batch_0][:]).to(device) + batch_1 = torch.tensor(block_end[batch_1][:]).to(device) + yield (batch_0, batch_1), layer_idx + +def corresponding_layers(representation_path_0: str, representation_path_1: str, device: str = "cpu"): + with LayerByIndex(representation_path_0) as reps_0, LayerByIndex(representation_path_1) as reps_1: + for layer_idx, (layer_0, layer_1) in enumerate(tqdm(zip(reps_0, reps_1), desc='Comparing Corresponding Layers', + total=len(reps_0), leave=False, initial = 1)): + for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', + total=len(layer_0), leave=False, initial = 1): + batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) + batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) + yield (batch_0, batch_1), layer_idx + +def all_layers(representation_path_0: str, representation_path_1: str, device: str = "cpu"): + with LayerByIndex(representation_path_0) as reps_0, LayerByIndex(representation_path_1) as reps_1: + for layer_0_idx, layer_0 in enumerate(tqdm(reps_0, desc='Model 0 Layers', + total=len(reps_0), leave=False, initial = 1)): + for layer_1_idx, layer_1 in enumerate(tqdm(reps_1, desc='Model 1 Layers', + total=len(reps_1), leave=False, initial = 1)): + for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', + total=len(layer_0), leave=False, initial = 1): + batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) + batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) + + yield (batch_0, batch_1), (layer_0_idx, layer_1_idx) + + +def main(): + representation_paths = [Path("/Users/elliotstein/Documents/Arcee/mergekit/representations/Representations_Qwen_Qwen2-7B-Instruct_microsoft_orca-math-word-problems-200k_4000.h5"), + Path("/Users/elliotstein/Documents/Arcee/mergekit/representations/Representations_arcee-ai_qwen2-7b-math-tess_microsoft_orca-math-word-problems-200k_4000.h5") + ] + + analysis_type = ModelAnalysisType.INDIVIDUAL + comparison_type = LayerComparisonType.SINGLE + metric_input = MetricInput.BATCHES + + valid_experiment(analysis_type, comparison_type, metric_input) + + for data, layer_idx in single(representation_paths[0]): + pass + + for data, layer_idx in block(representation_paths[0], 2): + pass + + for data, layer_idx in corresponding_layers(representation_paths[0], representation_paths[1]): + pass + + for data, layer_idx in all_layers(representation_paths[0], representation_paths[1]): + pass + +#%% + +if __name__ == "__main__": + main() + From 60459804483505023cb0ce156dbd2c5d65b04885 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 30 Jul 2024 11:07:04 +0100 Subject: [PATCH 50/64] address 
alphanumeric representation layer name ordering issue --- representations/representation_metrics.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 8750449f..d14637fe 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -11,6 +11,8 @@ from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer, ScatterPlot from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap +import math +from sklearn.manifold import TSNE from mergekit.architecture import WeightInfo from mergekit.common import ModelReference @@ -139,7 +141,6 @@ def aggregate(self) -> Metric: def clear(self) -> None: pass -import math class _CKA(object): # Class from https://github.com/jayroxis/CKA-similarity/blob/main/CKA.py def __init__(self): @@ -197,9 +198,6 @@ def process_dataset(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: def aggregate(self) -> Metric: return Metric(mean_std=MeanStd(mean=self.result)) - -from sklearn.manifold import TSNE - class t_SNE(MetricAggregator): def __init__(self, device: str = "cpu"): self.device = device @@ -216,7 +214,6 @@ def aggregate(self) -> Metric: ) ) - class LayerByIndex: def __init__(self, reps_path: str): self.reps_path = reps_path @@ -379,7 +376,7 @@ def main(config_yml: str): for reps_path in tqdm(model_paths, desc='Model', leave=False, total=len(model_paths), initial = 1): with LayerByIndex(reps_path) as reps: for i, layer in enumerate(tqdm(reps, desc='Layer', leave=False, initial = 1)): - layer_name = f'Layer_{i}' + layer_name = f'Layer_{i:03d}' layer_results = Layer(WeightInfo(name=layer_name)) for metric_name, metric_class in use_metrics.items(): metric = metric_class(device=device) From cb5e31a923140160e2e95108f0fff2446945ddc0 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 30 Jul 2024 11:10:13 +0100 Subject: [PATCH 51/64] MAJOR RESTRUCTURE of results, results handler and representation metrics --- mergekit/metric_methods/aggregator_metrics.py | 281 ++++++++++ mergekit/metric_methods/base.py | 48 +- mergekit/plot_tools/plot_tools.py | 209 ++------ representations/representation_metrics.py | 503 +++++++----------- .../visualise_representation_results.py | 15 +- 5 files changed, 542 insertions(+), 514 deletions(-) create mode 100644 mergekit/metric_methods/aggregator_metrics.py diff --git a/mergekit/metric_methods/aggregator_metrics.py b/mergekit/metric_methods/aggregator_metrics.py new file mode 100644 index 00000000..a1f76e19 --- /dev/null +++ b/mergekit/metric_methods/aggregator_metrics.py @@ -0,0 +1,281 @@ +import torch +import numpy as np +from typing import List, Dict, Any, Optional +from tqdm import tqdm +import torch.nn.functional as F +import math + +from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer, ScatterPlot +from mergekit.metric_methods.metrics import compute_histogram + +from mergekit.architecture import WeightInfo +from sklearn.manifold import TSNE + +import enum + +class ModelAnalysisType(enum.Enum): + INDIVIDUAL = "individual" + COMPARISON = "comparison" + +class LayerComparisonType(enum.Enum): + SINGLE = "single" # Analyse Layer i + BLOCK = "block" # Compare Layer i in model 1 with layer i+(block size) in model 1 + CORRESPONDING_LAYERS = "corresponding" # Compare Layer i in model 1 with layer i in model 
2 + ALL_LAYERS = "all_layers" # Compare Layer i in model 1 with Layer j in model (1 or 2) + +class MetricAggregator(): + def __init__(self, device: str = "cpu"): + self.device = device + self.valid_for = { + LayerComparisonType.SINGLE.value: False, + LayerComparisonType.BLOCK.value: False, + LayerComparisonType.CORRESPONDING_LAYERS.value: False, + LayerComparisonType.ALL_LAYERS.value: False + } + + def process_batch(self, batch_a: torch.Tensor, batch_b: Optional[torch.Tensor]) -> None: + raise NotImplementedError + + def aggregate(self) -> Metric: + raise NotImplementedError + + def clear(self) -> None: + raise NotImplementedError + +class Cosine_Similarity(MetricAggregator): + def __init__(self, device: str = "cpu"): + super().__init__(device=device) + self.cosine_similarities = torch.tensor([], device=self.device) + self.valid_for.update({ + LayerComparisonType.BLOCK.value: True, + LayerComparisonType.CORRESPONDING_LAYERS.value: True, + LayerComparisonType.ALL_LAYERS.value: True + }) + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + batch_similarities = F.cosine_similarity(batch_a, batch_b, dim=1) + self.cosine_similarities = torch.cat((self.cosine_similarities, batch_similarities)) + + def aggregate(self) -> Metric: + hist = compute_histogram(self.cosine_similarities, 100) + mean_std=MeanStd( + mean=self.cosine_similarities.mean().item(), + std=self.cosine_similarities.std().item() + ) + histogram=Histogram( + count=hist[0], + edges=hist[1], + widths=hist[2] + ) + self.__init__() + return Metric( + histogram=histogram, + mean_std=mean_std + ) + + def clear(self) -> None: + self.cosine_similarities = torch.tensor([], device=self.device) + +class MSE(MetricAggregator): + def __init__(self, device: str = "cpu"): + super().__init__(device=device) + self.square_errors = torch.tensor([], device=self.device) + self.valid_for.update({ + LayerComparisonType.BLOCK.value: True, + LayerComparisonType.CORRESPONDING_LAYERS.value: True, + }) + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + batch_square_errors = torch.square(batch_a - batch_b).flatten() + self.square_errors = torch.cat((self.square_errors, batch_square_errors)) + + def aggregate(self) -> Metric: + hist = compute_histogram(self.square_errors, 100) + mean_std=MeanStd( + mean=self.square_errors.mean().item(), + std=self.square_errors.std().item() + ) + histogram=Histogram( + count=hist[0], + edges=hist[1], + widths=hist[2] + ) + self.__init__() + return Metric( + histogram=histogram, + mean_std=mean_std + ) + +class Linearity_Score(MetricAggregator): + def __init__(self, device: str = "cpu"): + + super().__init__(device=device) + self.iterations = 0 + self.max_iterations = 5 + self.A = None + self.optimiser = None + self.initialised = False + self.done = False + self.losses = [] + + self.absolute_square_sum = 0 + self.num_elements = 0 + + self.valid_for.update({ + LayerComparisonType.BLOCK.value: True + }) + + def _first_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + batch_size, dimension = batch_a.size() + + self.A = torch.empty(dimension, dimension, device=self.device) + torch.nn.init.normal_(self.A) + self.A = torch.nn.Parameter(self.A) + + self.optimiser = torch.optim.SGD([self.A], lr=0.0001) + self.initialised = True + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + batch_a = batch_a / torch.norm(batch_a, dim=1, keepdim=True) + batch_b = batch_b / torch.norm(batch_b, dim=1, keepdim=True) # Check 
dimensionality (X) + if not self.initialised: + self._first_batch(batch_a, batch_b) + if self.done: # stop training A and evaluate + residuals = batch_a @ self.A - batch_b + self.absolute_square_sum += torch.abs(residuals).sum().item() + self.num_elements += residuals.numel() + + else: + + loss = torch.norm(batch_a @ self.A - batch_b) ** 2 + loss.backward() + self.losses.append(loss.item()) + print(f'Loss: {loss.item()}') + self.optimiser.step() + + self.iterations += 1 + + if self.iterations >= self.max_iterations: + self.done = True + + def aggregate(self) -> Metric: + linearity_score = 1 - self.absolute_square_sum / self.num_elements + self.__init__() + return Metric(mean_std=MeanStd(mean=linearity_score)) + + def clear(self) -> None: + pass + +class _CKA(object): + # Class from https://github.com/jayroxis/CKA-similarity/blob/main/CKA.py + def __init__(self): + pass + + def centering(self, K): + n = K.shape[0] + unit = np.ones([n, n]) + I = np.eye(n) + H = I - unit / n + return np.dot(np.dot(H, K), H) + + def rbf(self, X, sigma=None): + GX = np.dot(X, X.T) + KX = np.diag(GX) - GX + (np.diag(GX) - GX).T + if sigma is None: + mdist = np.median(KX[KX != 0]) + sigma = math.sqrt(mdist) + KX *= -0.5 / (sigma * sigma) + KX = np.exp(KX) + return KX + + def kernel_HSIC(self, X, Y, sigma): + return np.sum( + self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)) + ) + + def linear_HSIC(self, X, Y): + L_X = X @ X.T + L_Y = Y @ Y.T + return np.sum(self.centering(L_X) * self.centering(L_Y)) + + def linear_CKA(self, X, Y): + hsic = self.linear_HSIC(X, Y) + var1 = np.sqrt(self.linear_HSIC(X, X)) + var2 = np.sqrt(self.linear_HSIC(Y, Y)) + + return hsic / (var1 * var2) + + def kernel_CKA(self, X, Y, sigma=None): + hsic = self.kernel_HSIC(X, Y, sigma) + var1 = np.sqrt(self.kernel_HSIC(X, X, sigma)) + var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma)) + + return hsic / (var1 * var2) + +class CKA(MetricAggregator): + def __init__(self, device: str = "cpu"): + super().__init__(device=device) + self.cka = _CKA() + self.batches_a = [] + self.batches_b = [] + self.stop = False + self.max_batches = 10 + + self.valid_for.update({ + LayerComparisonType.BLOCK.value: True, + LayerComparisonType.CORRESPONDING_LAYERS.value: True, + LayerComparisonType.ALL_LAYERS.value: True + }) + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + if not self.stop: + self.batches_a.append(batch_a.cpu().numpy()) + self.batches_b.append(batch_b.cpu().numpy()) + + if len(self.batches_a) >= self.max_batches: + self.stop = True + + def aggregate(self) -> Metric: + self.result = self.cka.linear_CKA(np.concatenate(self.batches_a), + np.concatenate(self.batches_b)) + return Metric(mean_std=MeanStd(mean=self.result)) + +class t_SNE(MetricAggregator): + def __init__(self, device: str = "cpu"): + super().__init__(device=device) + self.tsne = TSNE(n_components=2, random_state=42) + self.batches = [] + self.max_batches = 5 + self.stop = False + + self.valid_for.update({ + LayerComparisonType.SINGLE.value: True, + }) + + def process_batch(self, batch: torch.Tensor) -> None: + if not self.stop: + self.batches.append(batch.cpu().numpy()) + + if len(self.batches) >= self.max_batches: + self.stop = True + + def aggregate(self) -> Metric: + data = np.concatenate(self.batches) + self.result = self.tsne.fit_transform(data) + + metric = Metric( + scatter_plot=ScatterPlot( + x=self.result[:, 0], + y=self.result[:, 1], + ) + ) + self.__init__(self.device) # Reset ready for next layer + return metric + 
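For orientation, a minimal sketch of how these aggregator classes are meant to be driven, assuming this patch series is applied so that mergekit.metric_methods.aggregator_metrics (and scikit-learn for t-SNE) is importable; the random tensors below are stand-ins for real hidden-state batches:

import torch
from mergekit.metric_methods.aggregator_metrics import MSE, Cosine_Similarity

# Hypothetical (batch, hidden_dim) representation batches for illustration only.
batch_a = torch.randn(8, 128)
batch_b = torch.randn(8, 128)

aggregators = [Cosine_Similarity(device="cpu"), MSE(device="cpu")]
for agg in aggregators:
    # process_batch is called once per batch of representations
    agg.process_batch(batch_a, batch_b)
# aggregate() collapses the accumulated batches into a single Metric
summaries = {agg.__class__.__name__.lower(): agg.aggregate() for agg in aggregators}
# Each returned Metric carries MeanStd and Histogram data that a Layer/Results
# object can store for later plotting.

The same two-phase protocol (process_batch repeatedly, then aggregate once) is what the experiment loops below rely on when selecting metrics via METRICS_TABLE.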
+METRICS_TABLE = { + 'cosine_similarity': Cosine_Similarity, + 'mse': MSE, + 'linearity_score': Linearity_Score, + 'cka': CKA, + 't-sne': t_SNE + } \ No newline at end of file diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 48dd7cf2..7a912158 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -29,28 +29,10 @@ class MetricMethod(MergeMethod): pass # Structure of the results object -# OLD -# Results -# └── layers: Dict[str, Layer] -# └── Layer -# ├── weight_info: WeightInfo -# └── metrics: Dict[str, List[Metric]] -# └── Metric -# ├── histogram: Optional[Histogram] -# ├── mean_std: Optional[MeanStd] -# ├── heatmap: Optional[Heatmap] -# └── model_ref: Optional[ModelReference] - -# Each Layer stores metrics under a key (e.g. 'cosine_similarity') in a dictionary. -# The values stored under each key are a **list** of Metric objects. This is to allow for a single metric type to be computed for each model. - # For metrics which compare between models, (e.g. cosine similarity) the list will contain a single Metric object storing the comparison data. - # For metrics which analyse individual models, (e.g. intrinsic dimension) the list will contain a Metric object for each model. - -# New # Results -# ├── model_ref: Optional[List[ModelReference]] # One for individual model, two for comparison +# ├── model_path: Optional[List[str]] # One for individual model, two for comparison # └── layers: Dict[str, Layer] # └── Layer # ├── weight_info: WeightInfo (remove?) @@ -117,10 +99,6 @@ def add_metric(self, metric: Metric, name: str): else: raise ValueError(f"Metric with name {name} already exists in layer {self.weight_info.layer_name}.") - # def add_metric_list(self, metric_list: List[Metric], name: str): - # for metric in metric_list: - # self.add_metric(metric, name) - def expand_to_fit(all_layer_names: List[str], values: List[float], subset_layer_names: List[str]) -> List[float]: """ Expands a list of values to fit a larger list of layer names, filling in missing values with None. @@ -150,16 +128,14 @@ class Results: def __init__(self): self.layers: Dict[str, Layer] = {} self.across_layer_metrics: Dict[str, Metric] = {} - self.model_refs: Optional[List[ModelReference]] = None + self.model_paths: Optional[List[str]] = None def add_layer(self, layer: Layer, name: str): if name not in self.layers.keys(): self.layers[name] = layer - # def get_metric(self, layer_name: str, metric_name: str) -> Metric: - # return self.get_layer(layer_name, metric_name) # Doesnt' Work! (X) - def load_metrics(self, metrics: List[Tuple[Task, Layer]], model_refs: Optional[List[ModelReference]] = None): - self.model_refs = model_refs + def load_metrics(self, metrics: List[Tuple[Task, Layer]], model_paths: Optional[List[str]] = None): + self.model_paths = model_paths for task, metric in metrics: if metric is not None: self.add_layer(metric, name=task.weight_info.name) @@ -169,22 +145,13 @@ def get_lineplot_data(self, metric_name: str): means, stds = [],[] available_line_plots = self.available_plot_types(PlotType.MEAN_STD.value) - assert metric_name in available_line_plots, f"Metric {metric_name} does not have mean/std data available." 
+ if metric_name not in available_line_plots: + return [], [] layers_with_data = available_line_plots[metric_name] means = [self.layers[layer].metrics[metric_name].mean_std.mean for layer in layers_with_data] stds = [self.layers[layer].metrics[metric_name].mean_std.std for layer in layers_with_data] - # for name, layer in self.layers.items(): - # if metric_name in layer.metrics: - # # for model_result in layer.metrics[metric_name]: - # # model_ref = model_result.model_refs if model_result.model_refs else 'all' - # means.append(layer.metrics[metric_name].mean_std.mean) - # stds.append(layer.metrics[metric_name].mean_std.std) - # layers.append(name) - - # # means_list, stds_list, model_references = list(means.values()), list(stds.values()), list(means.keys()) - # # for i, model_ref in enumerate(model_references): means = expand_to_fit(all_layer_names=list(self.layers.keys()), values=means, subset_layer_names=layers_with_data) stds = expand_to_fit(all_layer_names=list(self.layers.keys()), values=stds, subset_layer_names=layers_with_data) @@ -241,7 +208,6 @@ def print_metric_summary(self): print("Available Metrics Summary:") for metric, info in metric_info.items(): print(f"\nMetric: {metric}") - # print(f" Available in layers: {', '.join(info['layers'])}") print(f" Has mean/std: {'Yes' if info[PlotType.MEAN_STD.value] else 'No'}") print(f" Has histogram: {'Yes' if info[PlotType.HISTOGRAM.value] else 'No'}") print(f" Has heatmap: {'Yes' if info[PlotType.HEATMAP.value] else 'No'}") @@ -260,7 +226,7 @@ def save(self, path: str): pickle.dump(self, f) def load(self, path: str): - path_obj = Path(path) + path_obj = Path(path).resolve() if path_obj.exists() and path_obj.is_file(): with open(path_obj, 'rb') as f: results = pickle.load(f) diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 3ebc4409..38f93546 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -16,10 +16,6 @@ global_shapes_list = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] class ResultsHandler: - # Can accept many intra-model results, but only one inter-model result!! 
(X) - """ - - """ def __init__(self): self.intra_model_results: Dict[ModelReference, Results] = {} self.inter_model_results: Results = None @@ -32,16 +28,21 @@ def __init__(self): def load_results(self, results: Results): results.finalise() - if len(results.model_refs) == 2: + if len(results.model_paths) == 2: self.inter_model_results = results - elif len(results.model_refs) == 1: - self.intra_model_results[results.model_refs[0]] = results + elif len(results.model_paths) == 1: + # key = results.model_paths[0] + key = len(self.intra_model_results) + self.intra_model_results[key] = results else: - raise ValueError("Results should have either 1 or 2 model_refs") + raise ValueError("Results should have either 1 or 2 model_paths") for plot_type in self.available_layer_plots.keys(): - self.available_layer_plots[plot_type] = list(self.inter_model_results.available_plot_types(plot_type).keys()) - for model_ref, results in self.intra_model_results.items(): + + # if self.inter_model_results is not None: + self.available_layer_plots[plot_type] += list(self.inter_model_results.available_plot_types(plot_type).keys()) + # if self.inter_model_results is not None: + for model_path, results in self.intra_model_results.items(): self.available_layer_plots[plot_type] += list(results.available_plot_types(plot_type).keys()) self.available_layer_plots[plot_type] = list(set(self.available_layer_plots[plot_type])) @@ -63,10 +64,6 @@ def categorise_layers(self, layer_names) -> List[str]: return categories def plotly_line_plots(self, metric_name:str): - # if metric_name not in self.metric_names: - # print(f"Stat {metric_name} not found") - # return [], [] - traces = [] if metric_name in self.inter_model_results.available_plot_types('line_plot'): # bring if case into loop? (X) layer_names = self.inter_model_results.layer_names @@ -77,12 +74,13 @@ def plotly_line_plots(self, metric_name:str): else: unique_categories = list(self.intra_model_results.keys()) - for i, (model_ref, results) in enumerate(self.intra_model_results.items()): + for i, (model_path, results) in enumerate(self.intra_model_results.items()): layer_names = results.layer_names means, stds = results.get_lineplot_data(metric_name) - categorised_layers = [model_ref]*len(layer_names) # Different category for each model, every layer in each model has the same category - shape = global_shapes_list[i%len(global_shapes_list)] - traces.extend(self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories, shape)) + if means: + categorised_layers = [model_path]*len(layer_names) # Different category for each model, every layer in each model has the same category + shape = global_shapes_list[i%len(global_shapes_list)] + traces.extend(self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories, shape)) return traces, layer_names @@ -143,7 +141,8 @@ def get_traces(self, data:List, plot_type): elif plot_type == PlotType.SCATTER_PLOT.value: traces = [go.Scatter( x = d.x, - y = d.y + y = d.y, + mode='markers' ) for d in data] elif plot_type == PlotType.HISTOGRAM.value: traces = [] @@ -166,46 +165,6 @@ def get_traces(self, data:List, plot_type): return traces - - # def plotly_layer_heatmap(self, layer_name:str, metric_name:str): - # """ - # Plot the stat values as a heatmap. - - # Args: - # layer_name (str): The name of the layer. - # metric_name (str): The name of the stat to plot. - # Returns: - # go.Heatmap: Plotly Heatmap object. 
- # """ - # heatmaps = [] - # if metric_name in self.inter_model_results.available_plot_types('histogram').keys() \ - # and layer_name in self.inter_model_results.available_plot_types('histogram')[metric_name]: - # heatmaps.append(self.inter_model_results.layers[layer_name].metrics[metric_name].heatmap.data) - # else: - # for model_ref, results in self.intra_model_results.items(): - # if metric_name in results.available_plot_types('heatmap'): - # heatmaps.append(results.layers[layer_name].metrics[metric_name].heatmap.data) - - # return [go.Heatmap( - # z=data, - # colorscale='RdBu' - # ) for data in heatmaps] - - # def plotly_scatter_plot(self, layer_name:str, metric_name:str): - # scatters = [] - # if metric_name in self.inter_model_results.available_plot_types('scatter_plot').keys() \ - # and layer_name in self.inter_model_results.available_plot_types('scatter_plot')[metric_name]: - # scatters.append(self.inter_model_results.layers[layer_name].metrics[metric_name].scatter_plot) - # else: - # for model_ref, results in self.intra_model_results.items(): - # if metric_name in results.available_plot_types('scatter_plot'): - # scatters.append(results.layers[layer_name].metrics[metric_name].scatter_plot) - - # return [go.Scatter( - # x = scatter.x, - # y = scatter.y - # ) for scatter in scatters] - def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): """ Set the attributes of the plot. @@ -225,37 +184,7 @@ def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): for kwarg in ax_kwargs: if kwarg in kwargs: getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) - - # def plotly_layer_histogram(self, layer_name: str, metric_name: str): - # histograms = [] - # if metric_name in self.inter_model_results.available_plot_types('histogram'): - # histograms.append(self.inter_model_results.layers[layer_name].metrics[metric_name]) - # else: - # for model_ref, results in self.intra_model_results.items(): - # if metric_name in results.available_plot_types('histogram'): - - # metric_list = self.results.layers[layer_name].metrics[metric_name] - - # traces = [] - # for i, metric in enumerate(metric_list): - # hist = metric.histogram - # count, edges, widths = hist.count, hist.edges, hist.widths - # traces.append(go.Bar( - # x=edges[:-1], - # y=count, - # width=widths, - # marker=dict( - # color=global_colours_list[i], - # opacity=0.75, - # line=dict( - # color='black', - # width=1 - # ) - # ), - # # name=str(metric.model_ref) - # )) - # return traces - + def layer_plot_options(self, layer_name: str): metric_options = [] for plot_type in PlotType: @@ -379,23 +308,14 @@ def display_layer_data(selected_metric, clickData): if not selected_metric: selected_metric = results_handler.layer_plot_options(layer_name)[0]['value'] - - # metric_options = [] - # for result in results_handler.all_results: - # metric_options.extend(result.available_metrics_at_layer(layer_name)) - # metric_options = list(set(metric_options)) - # selected_metric = metric_options[0] - metric_name, plot_type = selected_metric - # Define default axis titles - xaxis_title = "Value" - yaxis_title = "Count" - - # Update axis titles if plot_type is 'heatmap' - if plot_type.lower() == "heatmap": - xaxis_title = "Model 1 Head" - yaxis_title = "Model 0 Head" + if plot_type.lower() in ["heatmap", "scatter_plot"]: + xaxis_title = "Model 1" + yaxis_title = "Model 0" + elif plot_type.lower() == 'histogram': + xaxis_title = "Value" + yaxis_title = "Count" plot_function = { 'histogram': lambda layer_name, metric_name: 
results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.HISTOGRAM.value), @@ -411,7 +331,8 @@ def display_layer_data(selected_metric, clickData): return create_figure(traces=traces, title=f"{plot_type.title()} for {layer_name} | {metric_name}", xaxis_title=xaxis_title, - yaxis_title=yaxis_title + yaxis_title=yaxis_title, + plot_type=plot_type ) except (KeyError, IndexError, AttributeError) as e: @@ -442,42 +363,13 @@ def update_line_plot(selected_metric): yaxis=dict(title=selected_metric.replace('_', ' ').title()) ) return fig - # # Dynamically create callbacks for each heatmap plot - # if hasattr(results_handler.results, 'others') and isinstance(results_handler.results.others, dict): - # for i, (key, value) in enumerate(results_handler.results.others.items()): - # if isinstance(value.data, (list, np.ndarray)): # Assuming heatmap data is in array-like format - # @app.callback( - # Output(f'heatmap-plot-{i}', 'figure'), - # Input(f'heatmap-plot-{i}', 'id') # Dummy input to trigger the callback on load - # ) - # def update_heatmap_plot(_key=key): - # key = list(results_handler.results.others.keys())[int(_key.split('-')[-1])] - # heatmap_data = results_handler.results.others[key].data - # fig = go.Figure(data=go.Heatmap( - # z=heatmap_data, - # colorscale='Viridis', # Using Viridis colormap - # zmin=np.nanmin(heatmap_data), # Set the scale min to the min data value - # zmax=np.nanmax(heatmap_data), # Set the scale max to the max data value - # colorbar=dict(title='Scale') # Customize the color bar - # )) - # default_layout_options = { - # 'xaxis_title':"X Axis", - # 'yaxis_title':"Y Axis" - # } - # if results_handler.results.others[key].update_layout_options: - # default_layout_options.update(results_handler.results.others[key].update_layout_options) - # fig.update_layout( - # title=f"Heatmap: {_key}", - # **default_layout_options - # ) - # return fig for result in results_handler.all_results: if hasattr(result, 'across_layer_metrics'): for metric_name, metric in result.across_layer_metrics.items(): - for attr in ['histogram', 'heatmap', 'scatter_plot']: - if hasattr(metric, attr): - id=f'{attr}-plot-{metric_name}' + for plot_type in ['histogram', 'heatmap', 'scatter_plot']: + if hasattr(metric, plot_type): + id=f'{plot_type}-plot-{metric_name}' @app.callback( Output(id, 'figure'), @@ -489,26 +381,39 @@ def update_across_layers_plot(_id=id): 'histogram': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.HISTOGRAM), 'heatmap': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.HEATMAP), 'scatter_plot': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.SCATTER_PLOT) - }.get(attr, + }.get(plot_type, lambda *args, **kwargs: go.Figure()) traces = plot_function(metric_name=metric_name) return create_figure(traces=traces, title=f"{id} | {metric_name}", - # xaxis_title=xaxis_title, - # yaxis_title=yaxis_title + plot_type = plot_type ) - -def create_figure(traces, title, xaxis_title, yaxis_title): - fig = go.Figure() - for trace in traces: - fig.add_trace(trace) - - fig.update_layout( - title=title, - xaxis=dict(title=xaxis_title), - yaxis=dict(title=yaxis_title) - ) +from plotly.subplots import make_subplots +def create_figure(traces, title, xaxis_title, yaxis_title, plot_type): + if plot_type in ["scatter_plot", "heatmap"]: + num_plots = len(traces) + num_cols = 2 + num_rows = (num_plots + 1) // num_cols + + fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=[f"Plot {i+1}" for i in 
range(num_plots)]) + + for i, trace in enumerate(traces): + row = (i // num_cols) + 1 + col = (i % num_cols) + 1 + fig.add_trace(trace, row=row, col=col) + fig.update_xaxes(title_text=xaxis_title, row=row, col=col) + fig.update_yaxes(title_text=yaxis_title, row=row, col=col) + else: + fig = go.Figure() + for trace in traces: + fig.add_trace(trace) + + fig.update_layout( + title=title, + xaxis=dict(title=xaxis_title), + yaxis=dict(title=yaxis_title) + ) return fig diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index d14637fe..4e26ecb9 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -1,3 +1,4 @@ +#%% import torch import h5py import numpy as np @@ -6,214 +7,12 @@ from pathlib import Path from typing import List, Dict, Any, Optional from tqdm import tqdm -import torch.nn.functional as F -from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer, ScatterPlot -from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap - -import math -from sklearn.manifold import TSNE +from mergekit.metric_methods.base import Results, Layer +from mergekit.metric_methods.aggregator_metrics import ModelAnalysisType, LayerComparisonType, METRICS_TABLE from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference -from mergekit.merge_methods.base import MergeMethod - -class MetricAggregator: - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: - raise NotImplementedError - - def aggregate(self) -> Metric: - raise NotImplementedError - - def clear(self) -> None: - raise NotImplementedError - -class Cosine_Similarity(MetricAggregator): - def __init__(self, device: str = "cpu"): - self.device = device - self.cosine_similarities = torch.tensor([], device=self.device) - - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: - batch_similarities = F.cosine_similarity(batch_a, batch_b, dim=1) - self.cosine_similarities = torch.cat((self.cosine_similarities, batch_similarities)) - - def aggregate(self) -> Metric: - hist = compute_histogram(self.cosine_similarities, 100) - mean_std=MeanStd( - mean=self.cosine_similarities.mean().item(), - std=self.cosine_similarities.std().item() - ) - histogram=Histogram( - count=hist[0], - edges=hist[1], - widths=hist[2] - ) - self.clear() - return Metric( - histogram=histogram, - mean_std=mean_std - ) - - def clear(self) -> None: - self.cosine_similarities = torch.tensor([], device=self.device) - -class MSE(MetricAggregator): - def __init__(self, device: str = "cpu"): - self.device = device - self.square_errors = torch.tensor([], device=self.device) - - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: - batch_square_errors = torch.square(batch_a - batch_b).flatten() - self.square_errors = torch.cat((self.square_errors, batch_square_errors)) - - def aggregate(self) -> Metric: - hist = compute_histogram(self.square_errors, 100) - mean_std=MeanStd( - mean=self.square_errors.mean().item(), - std=self.square_errors.std().item() - ) - histogram=Histogram( - count=hist[0], - edges=hist[1], - widths=hist[2] - ) - self.clear() - return Metric( - histogram=histogram, - mean_std=mean_std - ) - - def clear(self) -> None: - self.square_errors = torch.tensor([], device=self.device) - -class Linearity_Score(MetricAggregator): - def __init__(self, 
device: str = "cpu"): - self.device = device - self.iterations = 0 - self.max_iterations = 250 - self.A = None - self.optimiser = None - self.initialised = False - self.done = False - self.losses = [] - - self.absolute_square_sum = 0 - self.num_elements = 0 - - def _first_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: - batch_size, dimension = batch_a.size() - - self.A = torch.empty(dimension, dimension, device=self.device) - torch.nn.init.normal_(self.A) - self.A = torch.nn.Parameter(self.A) - - self.optimiser = torch.optim.SGD([self.A], lr=0.0001) - self.initialised = True - - def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: - batch_a = batch_a / torch.norm(batch_a, dim=1, keepdim=True) - batch_b = batch_b / torch.norm(batch_b, dim=1, keepdim=True) # Check dimensionality (X) - if not self.initialised: - self._first_batch(batch_a, batch_b) - if self.done: # stop training A and evaluate - residuals = batch_a @ self.A - batch_b - self.absolute_square_sum += torch.abs(residuals).sum().item() - self.num_elements += residuals.numel() - - else: - - loss = torch.norm(batch_a @ self.A - batch_b) ** 2 - loss.backward() - self.losses.append(loss.item()) - print(f'Loss: {loss.item()}') - self.optimiser.step() - - self.iterations += 1 - - if self.iterations >= self.max_iterations: - self.done = True - - def aggregate(self) -> Metric: - linearity_score = 1 - self.absolute_square_sum / self.num_elements - self.clear() - return Metric(mean_std=MeanStd(mean=linearity_score)) - - def clear(self) -> None: - pass - -class _CKA(object): - # Class from https://github.com/jayroxis/CKA-similarity/blob/main/CKA.py - def __init__(self): - pass - - def centering(self, K): - n = K.shape[0] - unit = np.ones([n, n]) - I = np.eye(n) - H = I - unit / n - return np.dot(np.dot(H, K), H) - - def rbf(self, X, sigma=None): - GX = np.dot(X, X.T) - KX = np.diag(GX) - GX + (np.diag(GX) - GX).T - if sigma is None: - mdist = np.median(KX[KX != 0]) - sigma = math.sqrt(mdist) - KX *= -0.5 / (sigma * sigma) - KX = np.exp(KX) - return KX - - def kernel_HSIC(self, X, Y, sigma): - return np.sum( - self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)) - ) - - def linear_HSIC(self, X, Y): - L_X = X @ X.T - L_Y = Y @ Y.T - return np.sum(self.centering(L_X) * self.centering(L_Y)) - - def linear_CKA(self, X, Y): - hsic = self.linear_HSIC(X, Y) - var1 = np.sqrt(self.linear_HSIC(X, X)) - var2 = np.sqrt(self.linear_HSIC(Y, Y)) - - return hsic / (var1 * var2) - - def kernel_CKA(self, X, Y, sigma=None): - hsic = self.kernel_HSIC(X, Y, sigma) - var1 = np.sqrt(self.kernel_HSIC(X, X, sigma)) - var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma)) - - return hsic / (var1 * var2) - -class CKA(MetricAggregator): - def __init__(self, device: str = "cpu"): - self.device = device - self.cka = _CKA() - - def process_dataset(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: - self.result = self.cka.linear_CKA(batch_a.cpu().numpy(), batch_b.cpu().numpy()) - - def aggregate(self) -> Metric: - return Metric(mean_std=MeanStd(mean=self.result)) - -class t_SNE(MetricAggregator): - def __init__(self, device: str = "cpu"): - self.device = device - self.tsne = TSNE(n_components=2, random_state=42) - - def process_dataset(self, data: torch.Tensor) -> None: - self.result = self.tsne.fit_transform(data.cpu().numpy()) - - def aggregate(self) -> Metric: - return Metric( - scatter_plot=ScatterPlot( - x=self.result[:, 0], - y=self.result[:, 1], - ) - ) - + class LayerByIndex: def __init__(self, 
reps_path: str): self.reps_path = reps_path @@ -238,13 +37,49 @@ def __len__(self) -> int: def __iter__(self): return iter(self.representations[layer] for layer in self.layers) -def compare_representations(reps_path_a: str, reps_path_b: str, - metrics_classes: Dict[str, MetricAggregator], device: str, results: Results) -> Dict[str, Any]: - if results is None: +def valid_experiment(analysis_type, comparison_type): + if comparison_type == LayerComparisonType.ALL_LAYERS: + raise ValueError("Comparison type 'all_layers' is not supported") + if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: + raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") + if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") + +# Experiment Loops +def single(representations_path, metric_classes, results, device='cpu'): + if not results: results = Results() + with LayerByIndex(representations_path) as representations: + for layer in tqdm(representations, desc='Analysing Layer', + total=len(representations), leave=False, initial = 1): + layer_name = layer.name.split('/')[-1] + num = f"{int(layer_name.split('_')[-1]):03d}" + layer_name = f"Layer {num}" + metrics = [metric_class(device=device) for metric_class in metric_classes] + + for batch in tqdm(layer, desc='Batch', + total=len(layer), leave=False, initial = 1): + batch = torch.tensor(layer[batch][:], device=device) + + # Calculate the metrics for each batch + for metric in metrics: + metric.process_batch(batch) + + layer_results = Layer(WeightInfo(name=layer_name)) + # Aggregate over the batches and add to the layer results + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + # metric.clear() + + results.add_layer(layer_results, layer_name) + + return results - with LayerByIndex(reps_path_a) as representations_a, \ - LayerByIndex(reps_path_b) as representations_b: +def corresponding(representations_path_0, representation_path_1, metric_classes, results, device='cpu'): + if not results: + results = Results() + with LayerByIndex(representations_path_0) as representations_a, \ + LayerByIndex(representation_path_1) as representations_b: for layer_a, layer_b in tqdm(zip(representations_a, representations_b), desc='Comparing Representations at layer', @@ -254,8 +89,10 @@ def compare_representations(reps_path_a: str, reps_path_b: str, layer_b_name = layer_b.name.split('/')[-1] if layer_a_name != layer_b_name: raise ValueError(f'Layer mismatch: {layer_a_name} != {layer_b_name}') + num = f"{int(layer_a_name.split('_')[-1]):03d}" + layer_name = f"Layer {num}" - metrics = [metric_class(device=device) for metric_class in metrics_classes.values()] + metrics = [metric_class(device=device) for metric_class in metric_classes] for batch_a, batch_b in tqdm(zip(layer_a, layer_b), desc='Batch', total=len(layer_a), leave=False, initial = 1): @@ -266,28 +103,28 @@ def compare_representations(reps_path_a: str, reps_path_b: str, for metric in metrics: metric.process_batch(batch_a, batch_b) - layer_results = Layer(WeightInfo(name=layer_a_name)) + layer_results = Layer(WeightInfo(name=layer_name)) # Aggregate over the batches and add to the layer results for metric in metrics: layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - 
metric.clear() + # metric.clear() - results.add_layer(layer_results, layer_a_name) - - return results + results.add_layer(layer_results, layer_name) + return results -def compute_skip_block_metrics(reps_path: str, skip_layers: int, - metric_classes: List[MetricAggregator], device: str) -> Results: - results = Results() - with LayerByIndex(reps_path) as reps: - for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {skip_layers}-block, Block Start at Layer', - total=len(reps) - skip_layers, leave=False, initial = 1): - if idx + skip_layers >= len(reps): +def block(representations_path, block_size, metric_classes, results, device='cpu'): + if not results: + results = Results() + out = {metric().__class__.__name__.lower(): [] for metric in metric_classes} + with LayerByIndex(representations_path) as reps: + for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {block_size}-block, Block Start at Layer', + total=len(reps) - block_size, leave=False, initial = 1): + if idx + block_size >= len(reps): break # Create metrics metrics = [metric_class(device=device) for metric_class in metric_classes] - block_end = reps[idx + skip_layers] + block_end = reps[idx + block_size] for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', total=len(block_start), leave=False, initial = 1): @@ -298,107 +135,145 @@ def compute_skip_block_metrics(reps_path: str, skip_layers: int, metric.process_batch(batch_0, batch_1) # Aggregate metrics and add to results - layer_results = Layer(WeightInfo(name=f"Layer {idx}")) + layer_results = Layer(WeightInfo(name=f"Block {idx} size {block_size}")) for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + # layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + out[metric.__class__.__name__.lower()].append(metric.aggregate()) - results.add_layer(layer_results, f"Layer {idx}") + results.add_layer(layer_results, f"Block {idx} size {block_size}") + for metric in out: + out[metric] = np.array(out[metric]) + return out - return results +def all_layers(representations_path_0, representations_path_1, metric_classes, results, device='cpu'): + if not results: + results = Results() + with LayerByIndex(representations_path_0) as reps_0, LayerByIndex(representations_path_1) as reps_1: + for idx_0, layer_0 in enumerate(tqdm(reps_0, desc='Model 0 Layers', + total=len(reps_0), leave=False, initial = 1)): + for idx_1, layer_1 in enumerate(tqdm(reps_1, desc='Model 1 Layers', + total=len(reps_1), leave=False, initial = 1)): + if len(layer_0) != len(layer_1): + raise ValueError(f'Layer mismatch: {len(layer_0)} != {len(layer_1)}') + + metrics = [metric_class(device=device) for metric_class in metric_classes] + + for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', + total=len(layer_0), leave=False, initial = 1): + batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) + batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) + + for metric in metrics: + metric.process_batch(batch_0, batch_1) + + layer_results = Layer(WeightInfo(name=f"Layer {idx_0} - Layer {idx_1}")) + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + + results.add_layer(layer_results, f"Layer {idx_0} - Layer {idx_1}") -def results_list_to_heatmap(all_results, metric_names:List[str]) -> dict: - rows = len(all_results) - cols = max([len(result.layers) for result in all_results]) - heatmaps = {} - for metric_name in metric_names: - heatmap = np.full((rows, 
cols), np.nan) - - for i, result in enumerate(all_results): - for j, layer in enumerate(result.layers): - heatmap[i, j] = result.layers[layer].metrics[metric_name][0].mean_std.mean - heatmaps[metric_name] = Heatmap(data=heatmap, - update_layout_options = { - 'xaxis_title': 'Layer Number', - 'yaxis_title': 'Block Size', - }) - return heatmaps - -METRICS_TABLE = { - 'cosine_similarity': Cosine_Similarity, - 'mse': MSE, - 'linearity_score': Linearity_Score, - 'cka': CKA, - 't-sne': t_SNE -} + return results @click.command() -@click.option('--config_yml', default="./representations/config.yml", help='Merge configuration file.') -def main(config_yml: str): +@click.option('--config_yml', default="./representations/config.yml", help='path to the configuration file.') +def main(config_yml: str = "config.yml"): with open(config_yml, "r", encoding="utf-8") as fp: config = yaml.safe_load(fp) - model_paths = config['representation_paths'] + + representation_paths = [Path(model_path) for model_path in config['representation_paths']] metrics_toggle = config['metrics'] - skip_layers = config['block_analysis_parameters']['skip_layers'] + analysis_type = ModelAnalysisType(config['analysis_type']) + comparison_type = LayerComparisonType(config['comparison_type']) + + out_dir = representation_paths[0].parent.parent / 'stored_results' + + + + device = 'cuda' if torch.cuda.is_available() else \ + 'mps' if torch.backends.mps.is_available() else \ + 'cpu' + + + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in metrics_toggle.items() + if enabled} + - use_metrics = {name: METRICS_TABLE[name] for name, enabled in metrics_toggle.items() if enabled} + valid_experiment(analysis_type, comparison_type) - device = torch.device("cuda" if torch.cuda.is_available() else - "mps" if torch.backends.mps.is_available() else "cpu") + final_results = [] + if analysis_type == ModelAnalysisType.INDIVIDUAL: + out_paths = [out_dir / f"{str(rep).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" for rep in representation_paths] + for out_path in out_paths: + assert not Path(out_path).exists(), f'{out_path} already exists.' 
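+        # Individual analysis: each model's stored representations are analysed
+        # independently and written to their own results file.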
+ for representation_path in representation_paths: + individual_results = Results() + individual_results.model_paths = [representation_path] + + if not representation_path.exists(): + raise FileNotFoundError(f"Representation file {representation_path} not found") + + if comparison_type == LayerComparisonType.SINGLE: + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] + individual_results = single(representation_path, + metrics, + results=individual_results, + device=device) + + if comparison_type == LayerComparisonType.BLOCK: + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]] + heatmaps = {} + for metric in metrics: + heatmaps[metric().__class__.__name__.lower()] = np.array([]) + for block_size in range(1, 9): + block_res = block(representations_path=representation_path, + block_size=block_size, + metric_classes=metrics, + results=individual_results, + device=device) + for metric in metrics: + heatmaps[metric().__class__.__name__.lower()] = np.append(heatmaps[metric().__class__.__name__.lower()], block_res[metric().__class__.__name__.lower()]) + + for metric in metrics: + result.across_layer_metrics[metric.__class__.__name__.lower()] = heatmaps[metric.__class__.__name__.lower()] # Definitely a simpler way to code this (X) + + if comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] + individual_results = corresponding(representations_path_0=representation_path, + representation_path_1=representation_path, + metric_classes=metrics, + results=individual_results, + device=device) + + final_results.append(individual_results) - for path in model_paths: - if not Path(path).exists(): - raise FileNotFoundError(f"File not found: {path}") - - all_results = Results() - if config['compare_between_models']: - if len(model_paths) != 2: - raise ValueError("Expected 2 model paths for comparison") - - all_results = compare_representations(model_paths[0], model_paths[1], - metrics_classes=use_metrics, device=device, results=all_results) - - if config['block_analysis']: - for reps_path in tqdm(model_paths, desc='Model', leave=False, total=len(model_paths), initial = 1): - results_list = [] - metric_classes = list(use_metrics.values()) - for skip_layer in tqdm(skip_layers, desc='Skip Layers', initial = 1): - results_list.append( - compute_skip_block_metrics(reps_path, skip_layer, metric_classes=metric_classes, device=device) - ) + if analysis_type == ModelAnalysisType.COMPARISON: + out_paths = [out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in representation_paths])}+{str(list(use_metrics.keys()))}.json"] + assert not Path(out_path).exists(), f'{out_path} already exists.' 
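+        # Comparison analysis: a single Results object covers both models' representations.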
+ + comparison_results = Results() + comparison_results.model_paths = representation_paths + + if comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] + comparison_results = corresponding(representations_path_0=representation_paths[0], + representation_path_1=representation_paths[1], + metric_classes=metrics, + results=comparison_results, + device=device) + if comparison_type == LayerComparisonType.ALL_LAYERS: + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL_LAYERS.value]] + comparison_results = all_layers(representations_path_0=representation_paths[0], + representations_path_1=representation_paths[1], + metric_classes=metrics, + results=comparison_results, + device=device) + final_results.append(comparison_results) + + for result, out_path in zip(final_results, out_paths): + result.save(out_path) #SAVE AS WE GO, NOT ALL AT THE END (X) - heatmaps = results_list_to_heatmap(results_list, metric_names=[metric.__name__.lower() for metric in metric_classes]) - for metric_name, heatmap in heatmaps.items(): - all_results.across_layer_metircs[reps_path + '||' + metric_name] = heatmap # Address this - new implementation only ever has one model per results object - - if config['analyse_individually']: - results = Results() - for reps_path in tqdm(model_paths, desc='Model', leave=False, total=len(model_paths), initial = 1): - with LayerByIndex(reps_path) as reps: - for i, layer in enumerate(tqdm(reps, desc='Layer', leave=False, initial = 1)): - layer_name = f'Layer_{i:03d}' - layer_results = Layer(WeightInfo(name=layer_name)) - for metric_name, metric_class in use_metrics.items(): - metric = metric_class(device=device) - # Want automatic choice of metrics according to whether it requires batches or single data. - # For now only accept metrics that require single data. 
- if not hasattr(metric, 'process_dataset'): - print(f'{metric_name} does not support dataset processing') - continue - collect_batches = [] - for batch in tqdm(layer, desc='Batch', leave=False, initial = 1): - batch = torch.tensor(layer[batch][:]).to(device) - collect_batches.append(batch) - if len(collect_batches) == 32: - single_data = torch.cat(collect_batches) # check dimensionality - metric.process_dataset(single_data) - continue - layer_results.add_metric(metric.aggregate(), metric_name) - results.add_layer(layer_results, name=layer_name) - results.save(f'results.pkl') - # all_results.save('results.pkl') - - - -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py index 8e96c749..7569fe55 100644 --- a/representations/visualise_representation_results.py +++ b/representations/visualise_representation_results.py @@ -3,15 +3,16 @@ from mergekit.metric_methods.base import Results @click.command() -@click.option('--results_path', - default="./representations/results.pkl", +@click.option('--input_dir', + default="/Users/elliotstein/Documents/Arcee/mergekit/representations/results_to_visualise", help="path to load the results from.") -def main(results_path): - results = Results() - results = results.load(results_path) - +def main(input_dir): handler = ResultsHandler() - handler.load_results(results) + for res in Path(input_dir).iterdir(): + results = Results() + results = results.load(res.absolute()) + + handler.load_results(results) app = create_app(results_handler=handler) app.run_server() From dfdcb0023820c9ca4e4c1ecb49a3589586f3dc7c Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 30 Jul 2024 11:11:05 +0100 Subject: [PATCH 52/64] tidy up --- .gitignore | 4 ++++ mergekit/scripts/run_metrics.py | 4 ++-- representations/config.yml | 17 +++++++------- representations/store_representations.py | 28 +++++++----------------- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/.gitignore b/.gitignore index 0f35a0d2..4c24f09f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +*.pkl +*.h5 +offload_folder/ + # Environment mergekit/bin/ diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py index 6139d6f3..4529edf3 100644 --- a/mergekit/scripts/run_metrics.py +++ b/mergekit/scripts/run_metrics.py @@ -46,7 +46,7 @@ def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory) low_cpu_memory=low_cpu_memory, ), ) - intra_results[model.model] = Results().load_metrics(metrics_out, model_refs=[model.model]) + intra_results[model.model] = Results().load_metrics(metrics_out, model_paths=[model.model.model.path]) if inter_model: assert len(models) == 2, "Inter-model metrics require exactly 2 models" @@ -64,7 +64,7 @@ def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory) low_cpu_memory=low_cpu_memory, ), ) - inter_results = Results().load_metrics(metrics_out, model_refs=models) + inter_results = Results().load_metrics(metrics_out, model_paths=models) handler = ResultsHandler() diff --git a/representations/config.yml b/representations/config.yml index 9c60ca59..2677b614 100644 --- a/representations/config.yml +++ b/representations/config.yml @@ -1,14 +1,13 @@ representation_paths: -- ./representations/Representations_BEE-spoke-data_smol_llama-220M-GQA_train_4000.h5 -- 
./representations/Representations_BEE-spoke-data_smol_llama-220M-openhermes_train_4000.h5 +- "/Users/elliotstein/Documents/Arcee/mergekit/representations/stored_representations/Representations_Qwen_Qwen2-7B-Instruct_microsoft_orca-math-word-problems-200k_4000.h5" +- "/Users/elliotstein/Documents/Arcee/mergekit/representations/stored_representations/Representations_arcee-ai_qwen2-7b-math-tess_microsoft_orca-math-word-problems-200k_4000.h5" metrics: cosine_similarity: true - mse: true - linearity_score: true + mse: false + linearity_score: false + cka: false + t-sne: false -compare_between_models: true -block_analysis: true -block_analysis_parameters: - skip_layers: [1,2,3,4,5,6,7,8,9] -analyse_individual_models: true \ No newline at end of file +analysis_type: "individual" +comparison_type: "block" \ No newline at end of file diff --git a/representations/store_representations.py b/representations/store_representations.py index c2c1964c..75ac099f 100644 --- a/representations/store_representations.py +++ b/representations/store_representations.py @@ -1,10 +1,7 @@ # WORK IN PROGRESS import click -import torch -import yaml - -from mergekit.config import MergeConfiguration +import h5py import logging import numpy as np @@ -12,9 +9,8 @@ import torch from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers import AutoModelForCausalLM, AutoTokenizer import datasets - import os logging.basicConfig(level=logging.INFO) @@ -22,14 +18,8 @@ # Set seed torch.manual_seed(42) np.random.seed(42) - -import torch from typing import List - -import h5py -import torch import random -import numpy as np def load_batch_from_hdf5(model_name, batch_idx): with h5py.File('batches.h5', 'r') as h5file: @@ -59,8 +49,7 @@ def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tens last_non_padded_hidden_states.append(torch.cat(batch_last_tokens, dim=0)) return last_non_padded_hidden_states - -def store_representations(model_path, output_path, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): +def store_representations(model_path, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): device = "cuda" if torch.cuda.is_available() \ else "mps" if torch.backends.mps.is_available() \ @@ -85,7 +74,7 @@ def store_representations(model_path, output_path, dataset_name, batch_size, max dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True) - output_name = f'Representations_{model.name_or_path}_{dataset_name}_{dataset_size}.h5'.replace("/","_") + output_name = f'{output_dir}/{model.name_or_path}_{dataset_name}_{dataset_size}.h5'.replace("/","_") assert not os.path.exists(output_name), f'{output_name} already exists.' 
with h5py.File(output_name, 'w') as h5file: @@ -101,7 +90,7 @@ def store_representations(model_path, output_path, dataset_name, batch_size, max # This adjustment is necessary for analyses focusing on the model's internal transformations last_non_padded_hidden_states = last_non_padded_hidden_states[1:] for layer, hidden_state in enumerate(last_non_padded_hidden_states): - layer_group = h5file.require_group(f'layer_{layer}') + layer_group = h5file.require_group(f'layer_{layer:03d}') file_name = f'batch_{batch_idx}.pt' layer_group.create_dataset(file_name, data=hidden_state.to('cpu'), compression="gzip") @@ -110,17 +99,16 @@ def store_representations(model_path, output_path, dataset_name, batch_size, max assert len(last_non_padded_hidden_states) == model.config.num_hidden_layers, "Length of last_non_padded_hidden_states \ does not match expected number of hidden layers." - @click.command() @click.option('--model_path', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') -@click.option('--output_path', default="./representations/", help='folder to store the result in.') +@click.option('--output_dir', default="./representations/stored_representations", help='folder to store the result in.') @click.option('--dataset_name', default="arcee-ai/sec-data-mini", help='dataset to use.') @click.option('--batch_size', default=8, help='batch size.') @click.option('--max_length', default=1024, help='maximum length of the input.') @click.option('--dataset_size', default=4000, help='size of the dataset.') @click.option('--dataset_column', default="text", help='column of the dataset to use.') @click.option('--dataset_subset', default="train", help='subset of the dataset to use.') -def main(model_path, output_path, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): - store_representations(model_path, output_path, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset) +def main(model_path, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): + store_representations(model_path, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset) if __name__ == "__main__": main() From 7bc80016ecf3bfb1e784c2097d6dbbf78d457f40 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Wed, 31 Jul 2024 11:25:10 +0100 Subject: [PATCH 53/64] further restructuring --- representations/representation_metrics.py | 233 +++++++++++++--------- representations/representations.py | 156 --------------- 2 files changed, 143 insertions(+), 246 deletions(-) delete mode 100644 representations/representations.py diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 4e26ecb9..8348c37f 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -8,9 +8,11 @@ from typing import List, Dict, Any, Optional from tqdm import tqdm +from abc import ABC, abstractmethod + from mergekit.metric_methods.base import Results, Layer from mergekit.metric_methods.aggregator_metrics import ModelAnalysisType, LayerComparisonType, METRICS_TABLE - +from dataclasses import dataclass from mergekit.architecture import WeightInfo class LayerByIndex: @@ -112,9 +114,7 @@ def corresponding(representations_path_0, representation_path_1, metric_classes, results.add_layer(layer_results, layer_name) return results -def block(representations_path, block_size, metric_classes, results, device='cpu'): - if not results: - results = 
Results() +def block(representations_path, block_size, metric_classes, device='cpu'): out = {metric().__class__.__name__.lower(): [] for metric in metric_classes} with LayerByIndex(representations_path) as reps: for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {block_size}-block, Block Start at Layer', @@ -140,7 +140,6 @@ def block(representations_path, block_size, metric_classes, results, device='cpu # layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) out[metric.__class__.__name__.lower()].append(metric.aggregate()) - results.add_layer(layer_results, f"Block {idx} size {block_size}") for metric in out: out[metric] = np.array(out[metric]) return out @@ -174,106 +173,160 @@ def all_layers(representations_path_0, representations_path_1, metric_classes, r return results -@click.command() -@click.option('--config_yml', default="./representations/config.yml", help='path to the configuration file.') -def main(config_yml: str = "config.yml"): - with open(config_yml, "r", encoding="utf-8") as fp: - config = yaml.safe_load(fp) - - - representation_paths = [Path(model_path) for model_path in config['representation_paths']] - metrics_toggle = config['metrics'] - analysis_type = ModelAnalysisType(config['analysis_type']) - comparison_type = LayerComparisonType(config['comparison_type']) - - out_dir = representation_paths[0].parent.parent / 'stored_results' - - - - device = 'cuda' if torch.cuda.is_available() else \ - 'mps' if torch.backends.mps.is_available() else \ - 'cpu' - - - use_metrics = {name: METRICS_TABLE[name] - for name, enabled in metrics_toggle.items() +@dataclass +class Configuration: + representation_paths: List[Path] + metrics: Dict[str, bool] + analysis_type: ModelAnalysisType + comparison_type: LayerComparisonType + out_dir: Path + device: str + + @classmethod + def from_dict(cls, config_dict: Dict): + return cls( + representation_paths=[Path(path) for path in config_dict['representation_paths']], + metrics=config_dict['metrics'], + analysis_type=ModelAnalysisType(config_dict['analysis_type']), + comparison_type=LayerComparisonType(config_dict['comparison_type']), + out_dir=Path(config_dict['out_dir']), + device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + ).validate() + + def validate(self): + if self.comparison_type == LayerComparisonType.ALL_LAYERS: + raise ValueError("Comparison type 'all_layers' is not supported") + if self.analysis_type == ModelAnalysisType.COMPARISON and self.comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: + raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") + if self.analysis_type == ModelAnalysisType.INDIVIDUAL and self.comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") + return self + + +class Experiment(ABC): + @abstractmethod + def run(self, config: Configuration): + pass + +class SingleExperiment(Experiment): + def run(self, config: Configuration): + # Implementation for single experiment + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() if enabled} - - valid_experiment(analysis_type, comparison_type) - - final_results = [] - if analysis_type == ModelAnalysisType.INDIVIDUAL: - out_paths = [out_dir / f"{str(rep).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" for rep in representation_paths] - for out_path in out_paths: - assert 
not Path(out_path).exists(), f'{out_path} already exists.' - for representation_path in representation_paths: + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] + for representation_path in config.representation_paths: individual_results = Results() individual_results.model_paths = [representation_path] if not representation_path.exists(): raise FileNotFoundError(f"Representation file {representation_path} not found") - if comparison_type == LayerComparisonType.SINGLE: - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] - individual_results = single(representation_path, + individual_results = single(representation_path, metrics, results=individual_results, - device=device) + device=config.device) + + out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" + individual_results.save(out_path) - if comparison_type == LayerComparisonType.BLOCK: - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]] - heatmaps = {} - for metric in metrics: - heatmaps[metric().__class__.__name__.lower()] = np.array([]) - for block_size in range(1, 9): - block_res = block(representations_path=representation_path, - block_size=block_size, - metric_classes=metrics, - results=individual_results, - device=device) - for metric in metrics: - heatmaps[metric().__class__.__name__.lower()] = np.append(heatmaps[metric().__class__.__name__.lower()], block_res[metric().__class__.__name__.lower()]) - - for metric in metrics: - result.across_layer_metrics[metric.__class__.__name__.lower()] = heatmaps[metric.__class__.__name__.lower()] # Definitely a simpler way to code this (X) +class CorrespondingExperiment(Experiment): + def run(self, config: Configuration): + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} - if comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] - individual_results = corresponding(representations_path_0=representation_path, - representation_path_1=representation_path, - metric_classes=metrics, - results=individual_results, - device=device) + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] + comparison_results.model_paths = [config.representation_paths[0], + config.representation_paths[0] if config.analysis_type == ModelAnalysisType.INDIVIDUAL.value else + config.representation_paths[1]] + + comparison_results = Results() + for rep_0 in config.representation_paths: + for rep_1 in config.representation_paths: + if (rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL.value) or \ + (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON.value): + + comparison_results = corresponding(representations_path_0=rep_0, + representation_path_1=rep_1, + metric_classes=metrics, + results=comparison_results, + device=config.device) + + out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in [rep_0, rep_1]])}+{str(list(use_metrics.keys()))}.json" + comparison_results.save(out_path) - final_results.append(individual_results) - - if analysis_type == ModelAnalysisType.COMPARISON: - out_paths = [out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep 
in representation_paths])}+{str(list(use_metrics.keys()))}.json"] - assert not Path(out_path).exists(), f'{out_path} already exists.' +class BlockExperiment(Experiment): + def run(self, config: Configuration): + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} - comparison_results = Results() - comparison_results.model_paths = representation_paths - - if comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] - comparison_results = corresponding(representations_path_0=representation_paths[0], - representation_path_1=representation_paths[1], - metric_classes=metrics, - results=comparison_results, - device=device) - if comparison_type == LayerComparisonType.ALL_LAYERS: - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL_LAYERS.value]] - comparison_results = all_layers(representations_path_0=representation_paths[0], - representations_path_1=representation_paths[1], - metric_classes=metrics, - results=comparison_results, - device=device) - final_results.append(comparison_results) - - for result, out_path in zip(final_results, out_paths): - result.save(out_path) #SAVE AS WE GO, NOT ALL AT THE END (X) + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]] + heatmaps = {} + for representation_path in config.representation_paths: + block_results = Results() + block_results.model_paths = [representation_path] + if not representation_path.exists(): + raise FileNotFoundError(f"Representation file {representation_path} not found") + for metric in metrics: + heatmaps[metric().__class__.__name__.lower()] = np.array([]) + for block_size in range(1, 9): + block_res = block(representations_path=representation_path, + block_size=block_size, + metric_classes=metrics, + device=config.device) + for metric in metrics: + heatmaps[metric().__class__.__name__.lower()] = np.append(heatmaps[metric().__class__.__name__.lower()], block_res[metric().__class__.__name__.lower()]) + for metric in metrics: + block_results.across_layer_metrics[metric.__class__.__name__.lower()] = heatmaps[metric.__class__.__name__.lower()] # Definitely a simpler way to code this (X) + +class AllLayersExperiment(Experiment): + def run(self, config: Configuration): + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} + + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL_LAYERS.value]] + comparison_results = Results() + comparison_results.model_paths = config.representation_paths + + comparison_results = all_layers(representations_path_0=config.representation_paths[0], + representations_path_1=config.representation_paths[1], + metric_classes=metrics, + results=comparison_results, + device=config.device) + + out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in config.representation_paths])}+{str(list(use_metrics.keys()))}.json" + comparison_results.save(out_path) + + +class ExperimentFactory: + experiments: Dict[str, Experiment] = { + "single": SingleExperiment, + "corresponding": CorrespondingExperiment, + "block": BlockExperiment, + "all_layers": AllLayersExperiment, + } + + @classmethod + def create(cls, experiment_type: str) -> Experiment: + experiment_class = cls.experiments.get(experiment_type) + if not experiment_class: + 
raise ValueError(f"Unknown experiment type: {experiment_type}") + return experiment_class() + +@click.command() +@click.option('--config_yml', default="./representations/config.yml", help='path to the configuration file.') +def main(config_yml: str = "config.yml"): + config = yaml.safe_load(open(config_yml, 'r')) + config['out_dir'] = Path(config['representation_paths'][0]).parent.parent / 'stored_results' + config = Configuration.from_dict(config) + + experiment = ExperimentFactory.create(config.comparison_type.name.lower()) + experiment.run(config) if __name__ == "__main__": main() diff --git a/representations/representations.py b/representations/representations.py deleted file mode 100644 index 00051352..00000000 --- a/representations/representations.py +++ /dev/null @@ -1,156 +0,0 @@ -#%% -import torch -import h5py -import numpy as np -import click -import yaml -from pathlib import Path -from typing import List, Dict, Any, Optional -from tqdm import tqdm -import torch.nn.functional as F - -from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric, Results, Layer -from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank, compute_histogram, cosine_similarity_heatmap - - -from mergekit.architecture import WeightInfo -from mergekit.common import ModelReference -from mergekit.merge_methods.base import MergeMethod - -import enum - -class LayerByIndex: - def __init__(self, reps_path: str): - self.reps_path = reps_path - self.representations = None - self.layers = None - - def __enter__(self): - self.representations = h5py.File(self.reps_path, 'r') - self.layers = list(self.representations.keys()) - return self - - def __exit__(self, *args, **kwargs): - if self.representations: - self.representations.close() - - def __getitem__(self, idx: int): - return self.representations[self.layers[idx]] - - def __len__(self) -> int: - return len(self.layers) - - def __iter__(self): - return iter(self.representations[layer] for layer in self.layers) - - -class ModelAnalysisType(enum.Enum): - INDIVIDUAL = "individual" - COMPARISON = "comparison" - - -class LayerComparisonType(enum.Enum): - SINGLE = "single" # Layer i - BLOCK = "block" # Layer i in model 1 and layer i+(block size) in model 1 - CORRESPONDING_LAYERS = "corresponding_layers" # Layer i in model 1 and layer i in model 2 - ALL_LAYERS = "all_layers" # Layer i in model 1 and Layer j in model (1 or 2) - - -class MetricInput(enum.Enum): - ONE_SHOT = "one_shot" - BATCHES = "batches" - - -def valid_experiment(analysis_type, comparison_type, metric_input): - if comparison_type == LayerComparisonType.ALL_LAYERS: - raise ValueError("Comparison type 'all_layers' is not supported") - if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: - raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") - if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: - raise ValueError("Comparison type 'corresponding_layers' only supported for comparison analysis") - - -def layer_loader(representation_path): - with LayerByIndex(representation_path) as representations: - for layer in tqdm(representations, desc='Analysing Layer', - total=len(representations), leave=False, initial = 1): - yield layer - -def batch_loader(layer, device): - for batch in tqdm(layer, desc='processing batch', - total=len(layer), leave=False, initial = 1): 
- yield torch.tensor(layer[batch][:], device=device) - -# Experiment Loops -def single(representation_path: str): - for layer_idx, layer in enumerate(layer_loader(representation_path)): - for batch in batch_loader(layer, device="cpu"): - yield batch, layer_idx - -def block(representation_path: str, block_size: int, device: str = "cpu"): - with LayerByIndex(representation_path) as reps: - for layer_idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {block_size}-block, Block Start at Layer', - total=len(reps) - block_size, leave=False, initial = 1): - if layer_idx + block_size >= len(reps): - break - - block_end = reps[layer_idx + block_size] - - for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', - total=len(block_start), leave=False, initial = 1): - batch_0 = torch.tensor(block_start[batch_0][:]).to(device) - batch_1 = torch.tensor(block_end[batch_1][:]).to(device) - yield (batch_0, batch_1), layer_idx - -def corresponding_layers(representation_path_0: str, representation_path_1: str, device: str = "cpu"): - with LayerByIndex(representation_path_0) as reps_0, LayerByIndex(representation_path_1) as reps_1: - for layer_idx, (layer_0, layer_1) in enumerate(tqdm(zip(reps_0, reps_1), desc='Comparing Corresponding Layers', - total=len(reps_0), leave=False, initial = 1)): - for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', - total=len(layer_0), leave=False, initial = 1): - batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) - batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) - yield (batch_0, batch_1), layer_idx - -def all_layers(representation_path_0: str, representation_path_1: str, device: str = "cpu"): - with LayerByIndex(representation_path_0) as reps_0, LayerByIndex(representation_path_1) as reps_1: - for layer_0_idx, layer_0 in enumerate(tqdm(reps_0, desc='Model 0 Layers', - total=len(reps_0), leave=False, initial = 1)): - for layer_1_idx, layer_1 in enumerate(tqdm(reps_1, desc='Model 1 Layers', - total=len(reps_1), leave=False, initial = 1)): - for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', - total=len(layer_0), leave=False, initial = 1): - batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) - batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) - - yield (batch_0, batch_1), (layer_0_idx, layer_1_idx) - - -def main(): - representation_paths = [Path("/Users/elliotstein/Documents/Arcee/mergekit/representations/Representations_Qwen_Qwen2-7B-Instruct_microsoft_orca-math-word-problems-200k_4000.h5"), - Path("/Users/elliotstein/Documents/Arcee/mergekit/representations/Representations_arcee-ai_qwen2-7b-math-tess_microsoft_orca-math-word-problems-200k_4000.h5") - ] - - analysis_type = ModelAnalysisType.INDIVIDUAL - comparison_type = LayerComparisonType.SINGLE - metric_input = MetricInput.BATCHES - - valid_experiment(analysis_type, comparison_type, metric_input) - - for data, layer_idx in single(representation_paths[0]): - pass - - for data, layer_idx in block(representation_paths[0], 2): - pass - - for data, layer_idx in corresponding_layers(representation_paths[0], representation_paths[1]): - pass - - for data, layer_idx in all_layers(representation_paths[0], representation_paths[1]): - pass - -#%% - -if __name__ == "__main__": - main() - From a5c58115ec44c831d727e0dc5ba8ab12026fdf94 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 1 Aug 2024 11:12:32 +0100 Subject: [PATCH 54/64] further refinements to representations experiment --- representations/{ => configs}/config.yml | 0 
representations/configs/config_comp_all.yml | 11 + .../configs/config_comp_corresponding.yml | 11 + representations/configs/config_i_all.yml | 11 + representations/configs/config_i_block.yml | 11 + representations/configs/config_i_single.yml | 11 + representations/experiment_setup.py | 346 ++++++++++++++++++ representations/representation_metrics.py | 327 +---------------- 8 files changed, 409 insertions(+), 319 deletions(-) rename representations/{ => configs}/config.yml (100%) create mode 100644 representations/configs/config_comp_all.yml create mode 100644 representations/configs/config_comp_corresponding.yml create mode 100644 representations/configs/config_i_all.yml create mode 100644 representations/configs/config_i_block.yml create mode 100644 representations/configs/config_i_single.yml create mode 100644 representations/experiment_setup.py diff --git a/representations/config.yml b/representations/configs/config.yml similarity index 100% rename from representations/config.yml rename to representations/configs/config.yml diff --git a/representations/configs/config_comp_all.yml b/representations/configs/config_comp_all.yml new file mode 100644 index 00000000..d6958d76 --- /dev/null +++ b/representations/configs/config_comp_all.yml @@ -0,0 +1,11 @@ +stored_representations: "/workspace/mergekit/representations/stored_representations" + +metrics: + cosine_similarity: true + mse: true + linearity_score: true + cka: true + t-sne: true + +analysis_type: "comparison" +comparison_type: "all" \ No newline at end of file diff --git a/representations/configs/config_comp_corresponding.yml b/representations/configs/config_comp_corresponding.yml new file mode 100644 index 00000000..5eedc042 --- /dev/null +++ b/representations/configs/config_comp_corresponding.yml @@ -0,0 +1,11 @@ +stored_representations: "/workspace/mergekit/representations/stored_representations" + +metrics: + cosine_similarity: true + mse: true + linearity_score: true + cka: true + t-sne: true + +analysis_type: "comparison" +comparison_type: "corresponding" \ No newline at end of file diff --git a/representations/configs/config_i_all.yml b/representations/configs/config_i_all.yml new file mode 100644 index 00000000..5d894605 --- /dev/null +++ b/representations/configs/config_i_all.yml @@ -0,0 +1,11 @@ +stored_representations: "/workspace/mergekit/representations/stored_representations" + +metrics: + cosine_similarity: true + mse: true + linearity_score: true + cka: true + t-sne: true + +analysis_type: "individual" +comparison_type: "all" \ No newline at end of file diff --git a/representations/configs/config_i_block.yml b/representations/configs/config_i_block.yml new file mode 100644 index 00000000..e86e57df --- /dev/null +++ b/representations/configs/config_i_block.yml @@ -0,0 +1,11 @@ +stored_representations: "/workspace/mergekit/representations/stored_representations" + +metrics: + cosine_similarity: true + mse: true + linearity_score: true + cka: true + t-sne: true + +analysis_type: "individual" +comparison_type: "block" \ No newline at end of file diff --git a/representations/configs/config_i_single.yml b/representations/configs/config_i_single.yml new file mode 100644 index 00000000..191975ff --- /dev/null +++ b/representations/configs/config_i_single.yml @@ -0,0 +1,11 @@ +stored_representations: "/workspace/mergekit/representations/stored_representations" + +metrics: + cosine_similarity: true + mse: true + linearity_score: true + cka: true + t-sne: true + +analysis_type: "individual" +comparison_type: "single" \ No 
newline at end of file diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py new file mode 100644 index 00000000..bba86347 --- /dev/null +++ b/representations/experiment_setup.py @@ -0,0 +1,346 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import List, Dict, Any, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from mergekit.architecture import WeightInfo +from mergekit.metric_methods.base import Results, Layer +from mergekit.metric_methods.aggregator_metrics import ModelAnalysisType, LayerComparisonType, METRICS_TABLE, MetricAggregator + +def check_memory(h5file): + # Check if full data can be loaded into memory + return True + +class LayerByIndex: + def __init__(self, reps_path: str, load_into_memory: bool = True): + self.reps_path = reps_path + self.representations = None + self.layers = None + self.load_into_memory = load_into_memory + self.in_memory_data = None + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + def __enter__(self): + self.representations = h5py.File(self.reps_path, 'r') + self.layers = list(self.representations.keys()) + + if self.load_into_memory: + print("Loading representations into memory") + self.in_memory_data = {layer: {} for layer in self.layers} + for layer_name in tqdm(self.layers, leave=False, initial = 1, desc='Loading Layer', total=len(self.layers)): + for batch_name, batch_data in self.representations[layer_name].items(): + data = torch.tensor(batch_data[...]).to(self.device) + self.in_memory_data[layer_name][batch_name] = data + return self + + def __exit__(self, *args, **kwargs): + if self.representations: + self.representations.close() + self.in_memory_data = None + + def __getitem__(self, idx: int): + if self.load_into_memory: + return self.in_memory_data[self.layers[idx]] + else: + return self.representations[self.layers[idx]] #(X) + + def __len__(self) -> int: + return len(self.layers) + + def __iter__(self): + if self.load_into_memory: + return iter(self.in_memory_data[layer] for layer in self.layers) + else: + return iter(self.representations[layer] for layer in self.layers) + +def valid_experiment(analysis_type, comparison_type): + if comparison_type == LayerComparisonType.ALL_LAYERS: + raise ValueError("Comparison type 'all_layers' is not supported") + if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: + raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") + if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") + +# Experiment Loops +def single(representations: LayerByIndex, metric_classes: List[MetricAggregator], results: Optional[Results], device='cpu'): + if not results: + results = Results() + # with LayerByIndex(representations_path) as representations: + for layer in tqdm(representations, desc='Analysing Layer', + total=len(representations), leave=False, initial = 1): + layer_name = layer.name.split('/')[-1] + num = f"{int(layer_name.split('_')[-1]):03d}" + layer_name = f"Layer {num}" + metrics = [metric_class(device=device) for metric_class in metric_classes] + + for batch in tqdm(layer, desc='Batch', + total=len(layer), leave=False, initial = 1): + batch = torch.tensor(layer[batch][:], device=device) + + # Calculate the metrics for 
each batch + for metric in metrics: + metric.process_batch(batch) + + layer_results = Layer(WeightInfo(name=layer_name)) + # Aggregate over the batches and add to the layer results + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + # metric.clear() + + results.add_layer(layer_results, layer_name) + + return results + +def corresponding(representations_0: LayerByIndex, representations_1: LayerByIndex, metric_classes: List[MetricAggregator], results: Optional[Results], device='cpu'): + if not results: + results = Results() + + for layer_0, layer_1 in tqdm(zip(representations_0, representations_1), + desc='Comparing Representations at layer', + total=len(representations_0), initial = 1): + + layer_0_name = layer_0.name.split('/')[-1] + layer_1_name = layer_1.name.split('/')[-1] + if layer_0_name != layer_1_name: + raise ValueError(f'Layer mismatch: {layer_0_name} != {layer_1_name}') + num = f"{int(layer_0_name.split('_')[-1]):03d}" + layer_name = f"Layer {num}" + + metrics = [metric_class(device=device) for metric_class in metric_classes] + + for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), + desc='Batch', total=len(layer_0), leave=False, initial = 1): + batch_0 = torch.tensor(layer_0[batch_0][:], device=device) + batch_1 = torch.tensor(layer_1[batch_1][:], device=device) + + # Calculate the metrics for each batch + for metric in metrics: + metric.process_batch(batch_0, batch_1) + + layer_results = Layer(WeightInfo(name=layer_name)) + # Aggregate over the batches and add to the layer results + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + # metric.clear() + + results.add_layer(layer_results, layer_name) + + return results + +def block(representations, block_size, metric_classes, device='cpu'): + out = {metric().__class__.__name__.lower(): [] for metric in metric_classes} + for idx, block_start in tqdm(enumerate(representations), desc=f'Comparing {block_size}-block, Block Start at Layer', + total=len(representations) - block_size, leave=False, initial = 1): + if idx + block_size >= len(representations): + break + + # Create metrics + metrics = [metric_class(device=device) for metric_class in metric_classes] + block_end = representations[idx + block_size] + + for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', + total=len(block_start), leave=False, initial = 1): + batch_0 = torch.tensor(block_start[batch_0][:]).to(device) + batch_1 = torch.tensor(block_end[batch_1][:]).to(device) + + for metric in metrics: + metric.process_batch(batch_0, batch_1) + + # Aggregate metrics and add to results + layer_results = Layer(WeightInfo(name=f"Block {idx} size {block_size}")) + for metric in metrics: + # layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + out[metric.__class__.__name__.lower()].append(metric.aggregate()) + + for metric in out: + out[metric] = np.array(out[metric]) + return out + +def all_layers(representations_0, representations_1, metric_classes, results, device='cpu'): + for idx_0, layer_0 in enumerate(tqdm(representations_0, desc='Model 0 Layers', + total=len(representations_0), leave=False, initial = 1)): + for idx_1, layer_1 in enumerate(tqdm(representations_1, desc='Model 1 Layers', + total=len(representations_1), leave=False, initial = 1)): + if len(layer_0) != len(layer_1): + raise ValueError(f'Layer mismatch: {len(layer_0)} != {len(layer_1)}') + + metrics = [metric_class(device=device) for metric_class in 
metric_classes] + + for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', + total=len(layer_0), leave=False, initial = 1): + batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) + batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) + + for metric in metrics: + metric.process_batch(batch_0, batch_1) + + layer_results = Layer(WeightInfo(name=f"Layer {idx_0} - Layer {idx_1}")) + for metric in metrics: + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) + + results.add_layer(layer_results, f"Layer {idx_0} - Layer {idx_1}") + + return results + +@dataclass +class Configuration: + representation_paths: List[Path] + metrics: Dict[str, bool] + analysis_type: ModelAnalysisType + comparison_type: LayerComparisonType + out_dir: Path + device: str + data: Optional[LayerByIndex] = None + + @classmethod + def from_dict(cls, config_dict: Dict): + return cls( + representation_paths=list(config_dict['stored_representations'].iterdir()), + metrics=config_dict['metrics'], + analysis_type=ModelAnalysisType(config_dict['analysis_type']), + comparison_type=LayerComparisonType(config_dict['comparison_type']), + out_dir=Path(config_dict['out_dir']), + device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + ).validate() + + def validate(self): + if self.comparison_type == LayerComparisonType.ALL_LAYERS: + raise ValueError("Comparison type 'all_layers' is not supported") + if self.analysis_type == ModelAnalysisType.COMPARISON and self.comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: + raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") + if self.analysis_type == ModelAnalysisType.INDIVIDUAL and self.comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") + return self + +class Experiment(ABC): + @abstractmethod + def run(self, config: Configuration): + pass + +class SingleExperiment(Experiment): + def run(self, config: Configuration): + # Implementation for single experiment + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} + + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] + for representation_path in config.representation_paths: + individual_results = Results() + individual_results.model_paths = [representation_path] + + if not representation_path.exists(): + raise FileNotFoundError(f"Representation file {representation_path} not found") + + with LayerByIndex(representation_path, load_into_memory = check_memory(representation_path)) as representations: + individual_results = single(representations, + metrics, + results=individual_results, + device=config.device) + + out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" + individual_results.save(out_path) + +class CorrespondingExperiment(Experiment): + def run(self, config: Configuration): + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} + + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] + comparison_results = Results() + stop = False + for rep_0 in config.representation_paths: + for rep_1 in config.representation_paths: + if ((rep_0 == rep_1 and config.analysis_type == 
ModelAnalysisType.INDIVIDUAL.value) or \ + (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON.value) and not stop): + if rep_0 != rep_1: + stop = True + + with LayerByIndex(rep_0) as representations_0, LayerByIndex(rep_1) as representations_1: + comparison_results = corresponding(representations_0=representations_0, + representations_1=representations_1, + metric_classes=metrics, + results=comparison_results, + device=config.device) + comparison_results.model_paths = [rep_0, rep_1] if rep_0 != rep_1 else [rep_0] + + out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in [rep_0, rep_1]])}+{str(list(use_metrics.keys()))}.json" + comparison_results.save(out_path) + +class BlockExperiment(Experiment): + def run(self, config: Configuration): + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} + + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]] + heatmaps = {} + for representation_path in config.representation_paths: + block_results = Results() + block_results.model_paths = [representation_path] + if not representation_path.exists(): + raise FileNotFoundError(f"Representation file {representation_path} not found") + for metric in metrics: + heatmaps[metric().__class__.__name__.lower()] = np.array([]) + with LayerByIndex(representation_path, load_into_memory = check_memory(representation_path)) as representations: + for block_size in range(1, 9): + block_res = block(representations=representations, + block_size=block_size, + metric_classes=metrics, + device=config.device) + for metric in metrics: + heatmaps[metric().__class__.__name__.lower()] = np.append(heatmaps[metric().__class__.__name__.lower()], block_res[metric().__class__.__name__.lower()]) + + for metric in metrics: + block_results.across_layer_metrics[metric().__class__.__name__.lower()] = heatmaps[metric().__class__.__name__.lower()] # Definitely a simpler way to code this (X) + + out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" + block_results.save(out_path) + +class AllLayersExperiment(Experiment): + def run(self, config: Configuration): + use_metrics = {name: METRICS_TABLE[name] + for name, enabled in config.metrics.items() + if enabled} + + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL_LAYERS.value]] + + stop = False + for rep_0 in config.representation_paths: + for rep_1 in config.representation_paths: + if ((rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL.value) or \ + (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON.value) and not stop): + if rep_0 != rep_1: + stop = True + with LayerByIndex(rep_0) as representations_0, LayerByIndex(rep_1) as representations_1: + + comparison_results = all_layers(representations_0=representations_0, + representations_1=representations_1, + metric_classes=metrics, + results=comparison_results, + device=config.device) + comparison_results.model_paths = [rep_0, rep_1] + + out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in config.representation_paths])}+{str(list(use_metrics.keys()))}.json" + comparison_results.save(out_path) + + +class ExperimentFactory: + experiments: Dict[str, Experiment] = { + "single": SingleExperiment, + "corresponding": CorrespondingExperiment, + "block": BlockExperiment, + "all_layers": 
AllLayersExperiment, + } + + @classmethod + def create(cls, experiment_type: str) -> Experiment: + experiment_class = cls.experiments.get(experiment_type) + if not experiment_class: + raise ValueError(f"Unknown experiment type: {experiment_type}") + return experiment_class() diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 8348c37f..2c8b30ba 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -1,328 +1,17 @@ -#%% -import torch -import h5py -import numpy as np -import click -import yaml from pathlib import Path -from typing import List, Dict, Any, Optional -from tqdm import tqdm - -from abc import ABC, abstractmethod - -from mergekit.metric_methods.base import Results, Layer -from mergekit.metric_methods.aggregator_metrics import ModelAnalysisType, LayerComparisonType, METRICS_TABLE -from dataclasses import dataclass -from mergekit.architecture import WeightInfo - -class LayerByIndex: - def __init__(self, reps_path: str): - self.reps_path = reps_path - self.representations = None - self.layers = None - - def __enter__(self): - self.representations = h5py.File(self.reps_path, 'r') - self.layers = list(self.representations.keys()) - return self - - def __exit__(self, *args, **kwargs): - if self.representations: - self.representations.close() - - def __getitem__(self, idx: int): - return self.representations[self.layers[idx]] - - def __len__(self) -> int: - return len(self.layers) - - def __iter__(self): - return iter(self.representations[layer] for layer in self.layers) - -def valid_experiment(analysis_type, comparison_type): - if comparison_type == LayerComparisonType.ALL_LAYERS: - raise ValueError("Comparison type 'all_layers' is not supported") - if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: - raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") - if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: - raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") - -# Experiment Loops -def single(representations_path, metric_classes, results, device='cpu'): - if not results: - results = Results() - with LayerByIndex(representations_path) as representations: - for layer in tqdm(representations, desc='Analysing Layer', - total=len(representations), leave=False, initial = 1): - layer_name = layer.name.split('/')[-1] - num = f"{int(layer_name.split('_')[-1]):03d}" - layer_name = f"Layer {num}" - metrics = [metric_class(device=device) for metric_class in metric_classes] - - for batch in tqdm(layer, desc='Batch', - total=len(layer), leave=False, initial = 1): - batch = torch.tensor(layer[batch][:], device=device) - - # Calculate the metrics for each batch - for metric in metrics: - metric.process_batch(batch) - - layer_results = Layer(WeightInfo(name=layer_name)) - # Aggregate over the batches and add to the layer results - for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - # metric.clear() - - results.add_layer(layer_results, layer_name) - - return results - -def corresponding(representations_path_0, representation_path_1, metric_classes, results, device='cpu'): - if not results: - results = Results() - with LayerByIndex(representations_path_0) as representations_a, \ - LayerByIndex(representation_path_1) as 
representations_b: - - for layer_a, layer_b in tqdm(zip(representations_a, representations_b), - desc='Comparing Representations at layer', - total=len(representations_a), initial = 1): - - layer_a_name = layer_a.name.split('/')[-1] - layer_b_name = layer_b.name.split('/')[-1] - if layer_a_name != layer_b_name: - raise ValueError(f'Layer mismatch: {layer_a_name} != {layer_b_name}') - num = f"{int(layer_a_name.split('_')[-1]):03d}" - layer_name = f"Layer {num}" - - metrics = [metric_class(device=device) for metric_class in metric_classes] - - for batch_a, batch_b in tqdm(zip(layer_a, layer_b), - desc='Batch', total=len(layer_a), leave=False, initial = 1): - batch_a = torch.tensor(layer_a[batch_a][:], device=device) - batch_b = torch.tensor(layer_b[batch_b][:], device=device) - - # Calculate the metrics for each batch - for metric in metrics: - metric.process_batch(batch_a, batch_b) - - layer_results = Layer(WeightInfo(name=layer_name)) - # Aggregate over the batches and add to the layer results - for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - # metric.clear() - - results.add_layer(layer_results, layer_name) - return results - -def block(representations_path, block_size, metric_classes, device='cpu'): - out = {metric().__class__.__name__.lower(): [] for metric in metric_classes} - with LayerByIndex(representations_path) as reps: - for idx, block_start in tqdm(enumerate(reps), desc=f'Comparing {block_size}-block, Block Start at Layer', - total=len(reps) - block_size, leave=False, initial = 1): - if idx + block_size >= len(reps): - break - - # Create metrics - metrics = [metric_class(device=device) for metric_class in metric_classes] - block_end = reps[idx + block_size] - - for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', - total=len(block_start), leave=False, initial = 1): - batch_0 = torch.tensor(block_start[batch_0][:]).to(device) - batch_1 = torch.tensor(block_end[batch_1][:]).to(device) - - for metric in metrics: - metric.process_batch(batch_0, batch_1) - - # Aggregate metrics and add to results - layer_results = Layer(WeightInfo(name=f"Block {idx} size {block_size}")) - for metric in metrics: - # layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - out[metric.__class__.__name__.lower()].append(metric.aggregate()) - for metric in out: - out[metric] = np.array(out[metric]) - return out - -def all_layers(representations_path_0, representations_path_1, metric_classes, results, device='cpu'): - if not results: - results = Results() - with LayerByIndex(representations_path_0) as reps_0, LayerByIndex(representations_path_1) as reps_1: - for idx_0, layer_0 in enumerate(tqdm(reps_0, desc='Model 0 Layers', - total=len(reps_0), leave=False, initial = 1)): - for idx_1, layer_1 in enumerate(tqdm(reps_1, desc='Model 1 Layers', - total=len(reps_1), leave=False, initial = 1)): - if len(layer_0) != len(layer_1): - raise ValueError(f'Layer mismatch: {len(layer_0)} != {len(layer_1)}') - - metrics = [metric_class(device=device) for metric_class in metric_classes] - - for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', - total=len(layer_0), leave=False, initial = 1): - batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) - batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) - - for metric in metrics: - metric.process_batch(batch_0, batch_1) - - layer_results = Layer(WeightInfo(name=f"Layer {idx_0} - Layer {idx_1}")) - for metric in metrics: - 
layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - - results.add_layer(layer_results, f"Layer {idx_0} - Layer {idx_1}") - - return results - -@dataclass -class Configuration: - representation_paths: List[Path] - metrics: Dict[str, bool] - analysis_type: ModelAnalysisType - comparison_type: LayerComparisonType - out_dir: Path - device: str - - @classmethod - def from_dict(cls, config_dict: Dict): - return cls( - representation_paths=[Path(path) for path in config_dict['representation_paths']], - metrics=config_dict['metrics'], - analysis_type=ModelAnalysisType(config_dict['analysis_type']), - comparison_type=LayerComparisonType(config_dict['comparison_type']), - out_dir=Path(config_dict['out_dir']), - device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' - ).validate() - - def validate(self): - if self.comparison_type == LayerComparisonType.ALL_LAYERS: - raise ValueError("Comparison type 'all_layers' is not supported") - if self.analysis_type == ModelAnalysisType.COMPARISON and self.comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: - raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") - if self.analysis_type == ModelAnalysisType.INDIVIDUAL and self.comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: - raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") - return self - - -class Experiment(ABC): - @abstractmethod - def run(self, config: Configuration): - pass - -class SingleExperiment(Experiment): - def run(self, config: Configuration): - # Implementation for single experiment - use_metrics = {name: METRICS_TABLE[name] - for name, enabled in config.metrics.items() - if enabled} - - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] - for representation_path in config.representation_paths: - individual_results = Results() - individual_results.model_paths = [representation_path] - - if not representation_path.exists(): - raise FileNotFoundError(f"Representation file {representation_path} not found") - - individual_results = single(representation_path, - metrics, - results=individual_results, - device=config.device) - - out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" - individual_results.save(out_path) - -class CorrespondingExperiment(Experiment): - def run(self, config: Configuration): - use_metrics = {name: METRICS_TABLE[name] - for name, enabled in config.metrics.items() - if enabled} - - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] - comparison_results.model_paths = [config.representation_paths[0], - config.representation_paths[0] if config.analysis_type == ModelAnalysisType.INDIVIDUAL.value else - config.representation_paths[1]] - - comparison_results = Results() - for rep_0 in config.representation_paths: - for rep_1 in config.representation_paths: - if (rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL.value) or \ - (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON.value): - - comparison_results = corresponding(representations_path_0=rep_0, - representation_path_1=rep_1, - metric_classes=metrics, - results=comparison_results, - device=config.device) - - out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in 
[rep_0, rep_1]])}+{str(list(use_metrics.keys()))}.json" - comparison_results.save(out_path) - -class BlockExperiment(Experiment): - def run(self, config: Configuration): - use_metrics = {name: METRICS_TABLE[name] - for name, enabled in config.metrics.items() - if enabled} - - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]] - heatmaps = {} - for representation_path in config.representation_paths: - block_results = Results() - block_results.model_paths = [representation_path] - if not representation_path.exists(): - raise FileNotFoundError(f"Representation file {representation_path} not found") - for metric in metrics: - heatmaps[metric().__class__.__name__.lower()] = np.array([]) - for block_size in range(1, 9): - block_res = block(representations_path=representation_path, - block_size=block_size, - metric_classes=metrics, - device=config.device) - for metric in metrics: - heatmaps[metric().__class__.__name__.lower()] = np.append(heatmaps[metric().__class__.__name__.lower()], block_res[metric().__class__.__name__.lower()]) - - for metric in metrics: - block_results.across_layer_metrics[metric.__class__.__name__.lower()] = heatmaps[metric.__class__.__name__.lower()] # Definitely a simpler way to code this (X) - -class AllLayersExperiment(Experiment): - def run(self, config: Configuration): - use_metrics = {name: METRICS_TABLE[name] - for name, enabled in config.metrics.items() - if enabled} - - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL_LAYERS.value]] - comparison_results = Results() - comparison_results.model_paths = config.representation_paths - - comparison_results = all_layers(representations_path_0=config.representation_paths[0], - representations_path_1=config.representation_paths[1], - metric_classes=metrics, - results=comparison_results, - device=config.device) - - out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in config.representation_paths])}+{str(list(use_metrics.keys()))}.json" - comparison_results.save(out_path) - - -class ExperimentFactory: - experiments: Dict[str, Experiment] = { - "single": SingleExperiment, - "corresponding": CorrespondingExperiment, - "block": BlockExperiment, - "all_layers": AllLayersExperiment, - } +import click +import yaml - @classmethod - def create(cls, experiment_type: str) -> Experiment: - experiment_class = cls.experiments.get(experiment_type) - if not experiment_class: - raise ValueError(f"Unknown experiment type: {experiment_type}") - return experiment_class() +from experiment_setup import Configuration, ExperimentFactory @click.command() -@click.option('--config_yml', default="./representations/config.yml", help='path to the configuration file.') +@click.option('--config_yml', default="config_i_block.yml", help='path to the configuration file.') def main(config_yml: str = "config.yml"): - config = yaml.safe_load(open(config_yml, 'r')) - config['out_dir'] = Path(config['representation_paths'][0]).parent.parent / 'stored_results' + mergekit_root = Path(__file__).parent.parent + config = yaml.safe_load(open(mergekit_root / 'representations' / 'configs' / config_yml, 'r')) + config['out_dir'] = mergekit_root / 'representations' / 'stored_results' + config['stored_representations'] = mergekit_root / 'representations' / 'stored_representations' config = Configuration.from_dict(config) experiment = ExperimentFactory.create(config.comparison_type.name.lower()) From 
3d3fd33139ee134136f54cf080493f632d5ef706 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 1 Aug 2024 11:13:14 +0100 Subject: [PATCH 55/64] added necessary imports and modules --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8814984d..1c916ff4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ test = ["pytest~=8.2.1"] evolve = ["ray", "cma", "lm_eval", "wandb"] vllm = ["vllm==0.3.2", "lm_eval[vllm]"] interactive_plot = ["networkx", "plotly", "matplotlib", "dash"] -representations = ["h5py", "datasets", "bitsandbytes"] +representations = ["h5py", "datasets", "bitsandbytes", "scikit-learn] [project.urls] repository = "https://github.com/cg123/mergekit" @@ -53,6 +53,7 @@ packages = [ "mergekit", "mergekit.io", "mergekit.merge_methods", + "mergekit.metric_methods", "mergekit.moe", "mergekit.scripts", "mergekit.evo", From 2dfa4f9ecd4e54f4fac7f9080d71798430a9e5f7 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 1 Aug 2024 15:32:32 +0100 Subject: [PATCH 56/64] minor restructuring and refactoring --- pyproject.toml | 3 +- representations/experiment_setup.py | 91 ++++++++++--------- representations/representation_metrics.py | 9 +- .../visualise_representation_results.py | 10 +- 4 files changed, 60 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c916ff4..fde49e1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ test = ["pytest~=8.2.1"] evolve = ["ray", "cma", "lm_eval", "wandb"] vllm = ["vllm==0.3.2", "lm_eval[vllm]"] interactive_plot = ["networkx", "plotly", "matplotlib", "dash"] -representations = ["h5py", "datasets", "bitsandbytes", "scikit-learn] +representations = ["h5py", "datasets", "bitsandbytes", "scikit-learn"] [project.urls] repository = "https://github.com/cg123/mergekit" @@ -54,6 +54,7 @@ packages = [ "mergekit.io", "mergekit.merge_methods", "mergekit.metric_methods", + "mergekit.plot_tools", "mergekit.moe", "mergekit.scripts", "mergekit.evo", diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py index bba86347..671b0ae0 100644 --- a/representations/experiment_setup.py +++ b/representations/experiment_setup.py @@ -6,6 +6,7 @@ import numpy as np import torch from tqdm import tqdm +import h5py from mergekit.architecture import WeightInfo from mergekit.metric_methods.base import Results, Layer @@ -51,35 +52,39 @@ def __getitem__(self, idx: int): def __len__(self) -> int: return len(self.layers) + # def __iter__(self): + # if self.load_into_memory: + # return ((layer, self.in_memory_data[layer]) for layer in self.layers) + # else: + # return ((layer, self.representations[layer]) for layer in self.layers) def __iter__(self): if self.load_into_memory: - return iter(self.in_memory_data[layer] for layer in self.layers) + return ((layer, self.in_memory_data[layer]) for layer in self.layers) else: - return iter(self.representations[layer] for layer in self.layers) + return ((layer, self.representations[layer]) for layer in self.layers) def valid_experiment(analysis_type, comparison_type): - if comparison_type == LayerComparisonType.ALL_LAYERS: + if comparison_type == LayerComparisonType.ALL: raise ValueError("Comparison type 'all_layers' is not supported") if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") - if analysis_type == 
ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING: raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") # Experiment Loops def single(representations: LayerByIndex, metric_classes: List[MetricAggregator], results: Optional[Results], device='cpu'): if not results: results = Results() - # with LayerByIndex(representations_path) as representations: - for layer in tqdm(representations, desc='Analysing Layer', + for layer_name, layer in tqdm(representations, desc='Analysing Layer', total=len(representations), leave=False, initial = 1): - layer_name = layer.name.split('/')[-1] - num = f"{int(layer_name.split('_')[-1]):03d}" - layer_name = f"Layer {num}" + # layer_name = layer.name.split('/')[-1] # layer is a dictionary of batches, doens't have .name attribute. + # num = f"{int(layer_name.split('_')[-1]):03d}" + # layer_name = f"Layer {num}" metrics = [metric_class(device=device) for metric_class in metric_classes] - for batch in tqdm(layer, desc='Batch', + for batch in tqdm(layer.values(), desc='Batch', total=len(layer), leave=False, initial = 1): - batch = torch.tensor(layer[batch][:], device=device) + # batch = torch.tensor(layer[batch][:], device=device) # Redundant now as LayerByIndex is already returning torch tensors # Calculate the metrics for each batch for metric in metrics: @@ -99,12 +104,12 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd if not results: results = Results() - for layer_0, layer_1 in tqdm(zip(representations_0, representations_1), + for (layer_0_name, layer_0), (layer_1_name, layer_1) in tqdm(zip(representations_0, representations_1), desc='Comparing Representations at layer', total=len(representations_0), initial = 1): - layer_0_name = layer_0.name.split('/')[-1] - layer_1_name = layer_1.name.split('/')[-1] + # layer_0_name = layer_0.name.split('/')[-1] + # layer_1_name = layer_1.name.split('/')[-1] if layer_0_name != layer_1_name: raise ValueError(f'Layer mismatch: {layer_0_name} != {layer_1_name}') num = f"{int(layer_0_name.split('_')[-1]):03d}" @@ -112,10 +117,10 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd metrics = [metric_class(device=device) for metric_class in metric_classes] - for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), + for batch_0, batch_1 in tqdm(zip(layer_0.values(), layer_1.values()), desc='Batch', total=len(layer_0), leave=False, initial = 1): - batch_0 = torch.tensor(layer_0[batch_0][:], device=device) - batch_1 = torch.tensor(layer_1[batch_1][:], device=device) + # batch_0 = torch.tensor(layer_0[batch_0][:], device=device) + # batch_1 = torch.tensor(layer_1[batch_1][:], device=device) # Calculate the metrics for each batch for metric in metrics: @@ -133,7 +138,7 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd def block(representations, block_size, metric_classes, device='cpu'): out = {metric().__class__.__name__.lower(): [] for metric in metric_classes} - for idx, block_start in tqdm(enumerate(representations), desc=f'Comparing {block_size}-block, Block Start at Layer', + for idx, (block_start_name, block_start) in tqdm(enumerate(representations), desc=f'Comparing {block_size}-block, Block Start at Layer', total=len(representations) - block_size, leave=False, initial = 1): if idx + block_size >= len(representations): break @@ -142,18 +147,16 @@ def 
block(representations, block_size, metric_classes, device='cpu'): metrics = [metric_class(device=device) for metric_class in metric_classes] block_end = representations[idx + block_size] - for batch_0, batch_1 in tqdm(zip(block_start, block_end), desc='Batch', + for batch_0, batch_1 in tqdm(zip(block_start.values(), block_end.values()), desc='Batch', total=len(block_start), leave=False, initial = 1): - batch_0 = torch.tensor(block_start[batch_0][:]).to(device) - batch_1 = torch.tensor(block_end[batch_1][:]).to(device) + # batch_0 = torch.tensor(block_start[batch_0][:]).to(device) + # batch_1 = torch.tensor(block_end[batch_1][:]).to(device) for metric in metrics: metric.process_batch(batch_0, batch_1) # Aggregate metrics and add to results - layer_results = Layer(WeightInfo(name=f"Block {idx} size {block_size}")) for metric in metrics: - # layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) out[metric.__class__.__name__.lower()].append(metric.aggregate()) for metric in out: @@ -161,28 +164,28 @@ def block(representations, block_size, metric_classes, device='cpu'): return out def all_layers(representations_0, representations_1, metric_classes, results, device='cpu'): - for idx_0, layer_0 in enumerate(tqdm(representations_0, desc='Model 0 Layers', - total=len(representations_0), leave=False, initial = 1)): - for idx_1, layer_1 in enumerate(tqdm(representations_1, desc='Model 1 Layers', - total=len(representations_1), leave=False, initial = 1)): + for layer_0_name, layer_0 in tqdm(representations_0, desc='Model 0 Layers', + total=len(representations_0), leave=False, initial = 1): + for layer_1_name, layer_1 in tqdm(representations_1, desc='Model 1 Layers', + total=len(representations_1), leave=False, initial = 1): if len(layer_0) != len(layer_1): raise ValueError(f'Layer mismatch: {len(layer_0)} != {len(layer_1)}') metrics = [metric_class(device=device) for metric_class in metric_classes] - for batch_0, batch_1 in tqdm(zip(layer_0, layer_1), desc='Batch', + for batch_0, batch_1 in tqdm(zip(layer_0.values(), layer_1.values()), desc='Batch', total=len(layer_0), leave=False, initial = 1): - batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) - batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) + # batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) + # batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) for metric in metrics: metric.process_batch(batch_0, batch_1) - layer_results = Layer(WeightInfo(name=f"Layer {idx_0} - Layer {idx_1}")) + layer_results = Layer(WeightInfo(name=f"{layer_0_name} - {layer_1_name}")) for metric in metrics: layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - results.add_layer(layer_results, f"Layer {idx_0} - Layer {idx_1}") + results.add_layer(layer_results, f"{layer_0_name} - {layer_1_name}") return results @@ -208,11 +211,11 @@ def from_dict(cls, config_dict: Dict): ).validate() def validate(self): - if self.comparison_type == LayerComparisonType.ALL_LAYERS: + if self.comparison_type == LayerComparisonType.ALL: raise ValueError("Comparison type 'all_layers' is not supported") if self.analysis_type == ModelAnalysisType.COMPARISON and self.comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]: raise ValueError("Comparison type 'single' and 'block' only supported for individual analysis") - if self.analysis_type == ModelAnalysisType.INDIVIDUAL and self.comparison_type == LayerComparisonType.CORRESPONDING_LAYERS: + if self.analysis_type == ModelAnalysisType.INDIVIDUAL and 
self.comparison_type == LayerComparisonType.CORRESPONDING: raise ValueError("Comparison type 'corresponding' only supported for comparison analysis") return self @@ -242,7 +245,7 @@ def run(self, config: Configuration): results=individual_results, device=config.device) - out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" + out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{config.analysis_type}+{config.comparison_type}.json" individual_results.save(out_path) class CorrespondingExperiment(Experiment): @@ -251,13 +254,13 @@ def run(self, config: Configuration): for name, enabled in config.metrics.items() if enabled} - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING_LAYERS.value]] + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING.value]] comparison_results = Results() stop = False for rep_0 in config.representation_paths: for rep_1 in config.representation_paths: - if ((rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL.value) or \ - (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON.value) and not stop): + if ((rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL) or \ + (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON) and not stop): if rep_0 != rep_1: stop = True @@ -269,7 +272,7 @@ def run(self, config: Configuration): device=config.device) comparison_results.model_paths = [rep_0, rep_1] if rep_0 != rep_1 else [rep_0] - out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in [rep_0, rep_1]])}+{str(list(use_metrics.keys()))}.json" + out_path = config.out_dir / f"{str([str(rep).split('/')[-1].split('.')[0] for rep in [rep_0, rep_1]])}+{config.analysis_type}+{config.comparison_type}.json" comparison_results.save(out_path) class BlockExperiment(Experiment): @@ -299,7 +302,7 @@ def run(self, config: Configuration): for metric in metrics: block_results.across_layer_metrics[metric().__class__.__name__.lower()] = heatmaps[metric().__class__.__name__.lower()] # Definitely a simpler way to code this (X) - out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{str(list(use_metrics.keys()))}.json" + out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{config.analysis_type}+{config.comparison_type}.json" block_results.save(out_path) class AllLayersExperiment(Experiment): @@ -308,13 +311,13 @@ def run(self, config: Configuration): for name, enabled in config.metrics.items() if enabled} - metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL_LAYERS.value]] + metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL.value]] stop = False for rep_0 in config.representation_paths: for rep_1 in config.representation_paths: - if ((rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL.value) or \ - (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON.value) and not stop): + if ((rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL) or \ + (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON) and not stop): if rep_0 != rep_1: stop = True with LayerByIndex(rep_0) as representations_0, LayerByIndex(rep_1) as representations_1: @@ -326,7 +329,7 @@ 
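Taken together, the raises in validate() leave only three usable analysis/comparison pairings: individual analysis with single or block comparisons, and comparison analysis with corresponding layers. A quick reference derived from those checks:

    # (analysis_type, comparison_type) -> accepted by Configuration.validate()
    VALID_COMBINATIONS = {
        ("individual", "single"): True,
        ("individual", "block"): True,
        ("individual", "corresponding"): False,
        ("individual", "all_layers"): False,   # 'all_layers' is rejected outright
        ("comparison", "single"): False,
        ("comparison", "block"): False,
        ("comparison", "corresponding"): True,
        ("comparison", "all_layers"): False,   # 'all_layers' is rejected outright
    }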
def run(self, config: Configuration): device=config.device) comparison_results.model_paths = [rep_0, rep_1] - out_path = config.out_dir / f"/{str([str(rep).split('/')[-1].split('.')[0] for rep in config.representation_paths])}+{str(list(use_metrics.keys()))}.json" + out_path = config.out_dir / f"{str([str(rep).split('/')[-1].split('.')[0] for rep in config.representation_paths])}+{config.analysis_type}+{config.comparison_type}.json" comparison_results.save(out_path) diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 2c8b30ba..675b2f58 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -5,9 +5,7 @@ from experiment_setup import Configuration, ExperimentFactory -@click.command() -@click.option('--config_yml', default="config_i_block.yml", help='path to the configuration file.') -def main(config_yml: str = "config.yml"): +def run(config_yml: str = "config.yml"): mergekit_root = Path(__file__).parent.parent config = yaml.safe_load(open(mergekit_root / 'representations' / 'configs' / config_yml, 'r')) config['out_dir'] = mergekit_root / 'representations' / 'stored_results' @@ -17,5 +15,10 @@ def main(config_yml: str = "config.yml"): experiment = ExperimentFactory.create(config.comparison_type.name.lower()) experiment.run(config) +@click.command() +@click.option('--config_yml', default="config_i_single.yml", help='path to the configuration file.') +def main(config_yml: str = "config.yml"): + run(config_yml) + if __name__ == "__main__": main() diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py index 7569fe55..9119be2f 100644 --- a/representations/visualise_representation_results.py +++ b/representations/visualise_representation_results.py @@ -1,11 +1,8 @@ import click from mergekit.plot_tools.plot_tools import create_app, ResultsHandler from mergekit.metric_methods.base import Results +from pathlib import Path -@click.command() -@click.option('--input_dir', - default="/Users/elliotstein/Documents/Arcee/mergekit/representations/results_to_visualise", - help="path to load the results from.") def main(input_dir): handler = ResultsHandler() for res in Path(input_dir).iterdir(): @@ -18,4 +15,7 @@ def main(input_dir): app.run_server() if __name__ == '__main__': - main() \ No newline at end of file + mergekit_root = Path(__file__).parent.parent + input_dir = mergekit_root / 'representations' / 'stored_results' + + main(input_dir) \ No newline at end of file From 60726d4f12be3acbf0ba40bd553a3d2b916d70c9 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Thu, 1 Aug 2024 22:26:11 +0100 Subject: [PATCH 57/64] Bug fixes --- mergekit/metric_methods/aggregator_metrics.py | 18 ++--- mergekit/plot_tools/plot_tools.py | 63 +++++++---------- representations/experiment_setup.py | 67 ++++++++++++------- representations/representation_metrics.py | 4 +- .../visualise_representation_results.py | 2 +- 5 files changed, 81 insertions(+), 73 deletions(-) diff --git a/mergekit/metric_methods/aggregator_metrics.py b/mergekit/metric_methods/aggregator_metrics.py index a1f76e19..818b016b 100644 --- a/mergekit/metric_methods/aggregator_metrics.py +++ b/mergekit/metric_methods/aggregator_metrics.py @@ -20,8 +20,8 @@ class ModelAnalysisType(enum.Enum): class LayerComparisonType(enum.Enum): SINGLE = "single" # Analyse Layer i BLOCK = "block" # Compare Layer i in model 1 with layer i+(block size) in model 1 - CORRESPONDING_LAYERS = "corresponding" # Compare 
Layer i in model 1 with layer i in model 2 - ALL_LAYERS = "all_layers" # Compare Layer i in model 1 with Layer j in model (1 or 2) + CORRESPONDING = "corresponding" # Compare Layer i in model 1 with layer i in model 2 + ALL = "all_layers" # Compare Layer i in model 1 with Layer j in model (1 or 2) class MetricAggregator(): def __init__(self, device: str = "cpu"): @@ -29,8 +29,8 @@ def __init__(self, device: str = "cpu"): self.valid_for = { LayerComparisonType.SINGLE.value: False, LayerComparisonType.BLOCK.value: False, - LayerComparisonType.CORRESPONDING_LAYERS.value: False, - LayerComparisonType.ALL_LAYERS.value: False + LayerComparisonType.CORRESPONDING.value: False, + LayerComparisonType.ALL.value: False } def process_batch(self, batch_a: torch.Tensor, batch_b: Optional[torch.Tensor]) -> None: @@ -48,8 +48,8 @@ def __init__(self, device: str = "cpu"): self.cosine_similarities = torch.tensor([], device=self.device) self.valid_for.update({ LayerComparisonType.BLOCK.value: True, - LayerComparisonType.CORRESPONDING_LAYERS.value: True, - LayerComparisonType.ALL_LAYERS.value: True + LayerComparisonType.CORRESPONDING.value: True, + LayerComparisonType.ALL.value: True }) def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: @@ -82,7 +82,7 @@ def __init__(self, device: str = "cpu"): self.square_errors = torch.tensor([], device=self.device) self.valid_for.update({ LayerComparisonType.BLOCK.value: True, - LayerComparisonType.CORRESPONDING_LAYERS.value: True, + LayerComparisonType.CORRESPONDING.value: True, }) def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: @@ -223,8 +223,8 @@ def __init__(self, device: str = "cpu"): self.valid_for.update({ LayerComparisonType.BLOCK.value: True, - LayerComparisonType.CORRESPONDING_LAYERS.value: True, - LayerComparisonType.ALL_LAYERS.value: True + LayerComparisonType.CORRESPONDING.value: True, + LayerComparisonType.ALL.value: True }) def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 38f93546..811ef0df 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -11,6 +11,7 @@ from mergekit.metric_methods.all_metrics import Layer from mergekit.metric_methods.base import Results, PlotType from mergekit.common import ModelReference +from plotly.subplots import make_subplots global_colours_list = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] global_shapes_list = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] @@ -39,11 +40,11 @@ def load_results(self, results: Results): for plot_type in self.available_layer_plots.keys(): - # if self.inter_model_results is not None: - self.available_layer_plots[plot_type] += list(self.inter_model_results.available_plot_types(plot_type).keys()) - # if self.inter_model_results is not None: - for model_path, results in self.intra_model_results.items(): - self.available_layer_plots[plot_type] += list(results.available_plot_types(plot_type).keys()) + if self.inter_model_results is not None: + self.available_layer_plots[plot_type] += list(self.inter_model_results.available_plot_types(plot_type).keys()) + if self.intra_model_results is not None: + for model_path, results in self.intra_model_results.items(): + self.available_layer_plots[plot_type] += list(results.available_plot_types(plot_type).keys()) self.available_layer_plots[plot_type] = 
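Each aggregator advertises which comparison modes it supports through its valid_for table, and the experiments filter on it with metric().valid_for[comparison.value]. A self-contained sketch of that pattern, with a hypothetical DummyMetric standing in for the real aggregators:

    import enum

    class LayerComparisonType(enum.Enum):
        SINGLE = "single"
        BLOCK = "block"
        CORRESPONDING = "corresponding"
        ALL = "all_layers"

    class DummyMetric:
        """Stand-in aggregator that only supports corresponding-layer comparison."""
        def __init__(self):
            self.valid_for = {t.value: False for t in LayerComparisonType}
            self.valid_for[LayerComparisonType.CORRESPONDING.value] = True

    def metrics_for(comparison, metric_classes):
        """Keep only the aggregator classes that declare support for this comparison."""
        return [cls for cls in metric_classes if cls().valid_for[comparison.value]]

    assert metrics_for(LayerComparisonType.CORRESPONDING, [DummyMetric]) == [DummyMetric]
    assert metrics_for(LayerComparisonType.BLOCK, [DummyMetric]) == []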
list(set(self.available_layer_plots[plot_type])) @@ -132,7 +133,7 @@ def plotly_layer_plot(self, layer_name:str, metric_name:str, plot_type:str): return self.get_traces(data, plot_type) # Can prob use type of data to determine plot type (X) - def get_traces(self, data:List, plot_type): + def get_traces(self, data:List, plot_type:str): # Can prob use type of data to determine plot type (X) if plot_type == PlotType.HEATMAP.value: traces = [go.Heatmap( z=d.data, @@ -195,10 +196,11 @@ def layer_plot_options(self, layer_name: str): for metric in self.inter_model_results.layers[layer_name].metrics_with_attribute(plot_type.value)] ) for result in self.all_results: - metric_options.extend([ - {"label": f"{metric.title()} {plot_type.value}", "value": [metric, plot_type.value]} - for metric in result.layers[layer_name].metrics_with_attribute(plot_type.value)] - ) + if layer_name in result.layers: + metric_options.extend([ + {"label": f"{metric.title()} {plot_type.value}", "value": [metric, plot_type.value]} + for metric in result.layers[layer_name].metrics_with_attribute(plot_type.value)] + ) break # Assuming all intra-model results have the same metrics return metric_options @@ -251,13 +253,13 @@ def create_across_layers_section(results_handler): plot_sections = [] for result in results: - if hasattr(result, 'across_layer_metrics'): + if getattr(result, 'across_layer_metrics', None): for metric_name, metric in result.across_layer_metrics.items(): - for attr in ['histogram', 'heatmap', 'scatter_plot']: - if hasattr(metric, attr): + for plot_type in ['histogram', 'heatmap', 'scatter_plot']: + if getattr(metric, plot_type, None): #(X) shouldn't need [0] - metric is being stored inside an array and shouldn't be! plot_sections.append(html.Div([ - html.H3(f'{attr+metric_name.replace("_", " ").title()} {attr.replace("_", " ").title()}', style={'textAlign': 'center'}), - dcc.Graph(id=f'{attr}-plot-{metric_name}', style={'width': '50%', 'height': '50%', 'position': 'relative'}) + html.H3(f'{plot_type+metric_name.replace("_", " ").title()} {plot_type.replace("_", " ").title()}', style={'textAlign': 'center'}), + dcc.Graph(id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}", style={'width': '50%', 'height': '50%', 'position': 'relative'}) ], className='container-fluid')) return html.Div(plot_sections) @@ -317,16 +319,7 @@ def display_layer_data(selected_metric, clickData): xaxis_title = "Value" yaxis_title = "Count" - plot_function = { - 'histogram': lambda layer_name, metric_name: results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.HISTOGRAM.value), - 'heatmap': lambda layer_name, metric_name: results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.HEATMAP.value), - 'scatter_plot': lambda layer_name, metric_name: results_handler.plotly_layer_plot(layer_name, metric_name, PlotType.SCATTER_PLOT.value) - } - - traces = plot_function[plot_type.lower()]( - layer_name=layer_name, - metric_name=metric_name - ) + traces = results_handler.plotly_layer_plot(layer_name, metric_name, plot_type) return create_figure(traces=traces, title=f"{plot_type.title()} for {layer_name} | {metric_name}", @@ -365,32 +358,26 @@ def update_line_plot(selected_metric): return fig for result in results_handler.all_results: - if hasattr(result, 'across_layer_metrics'): + if getattr(result, 'across_layer_metrics', None): for metric_name, metric in result.across_layer_metrics.items(): for plot_type in ['histogram', 'heatmap', 'scatter_plot']: - if hasattr(metric, 
plot_type): - id=f'{plot_type}-plot-{metric_name}' + if getattr(metric, plot_type, None): #(X) shouldn't need [0] - metric is being stored inside an array and shouldn't be! + id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}" @app.callback( Output(id, 'figure'), Input(id, 'id') ) - def update_across_layers_plot(_id=id): - metric_name = _id.split('-')[-1] - plot_function = { - 'histogram': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.HISTOGRAM), - 'heatmap': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.HEATMAP), - 'scatter_plot': lambda metric_name: results_handler.plotly_layer_plot(metric_name, PlotType.SCATTER_PLOT) - }.get(plot_type, - lambda *args, **kwargs: go.Figure()) - traces = plot_function(metric_name=metric_name) + def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric): + traces = results_handler.get_traces(data = [getattr(metric, plot_type)], plot_type = plot_type) return create_figure(traces=traces, title=f"{id} | {metric_name}", + xaxis_title="Temp title", + yaxis_title=metric_name, plot_type = plot_type ) -from plotly.subplots import make_subplots def create_figure(traces, title, xaxis_title, yaxis_title, plot_type): if plot_type in ["scatter_plot", "heatmap"]: num_plots = len(traces) diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py index 671b0ae0..7eece887 100644 --- a/representations/experiment_setup.py +++ b/representations/experiment_setup.py @@ -12,10 +12,35 @@ from mergekit.metric_methods.base import Results, Layer from mergekit.metric_methods.aggregator_metrics import ModelAnalysisType, LayerComparisonType, METRICS_TABLE, MetricAggregator -def check_memory(h5file): +from mergekit.metric_methods.base import Heatmap, Metric + +def check_memory(h5file, h5file_2=None): # Check if full data can be loaded into memory + # Not yet implemented (X) return True +def convert_to_2d_array(arrays): + """ + Convert a list of 1D numpy arrays into a single 2D array. + + Parameters: + arrays (list of np.ndarray): List of 1D numpy arrays. + + Returns: + np.ndarray: 2D numpy array with dimensions N x max_length, filled with NaNs where necessary. 
+ """ + # Determine the length of the longest array + max_length = max(len(arr) for arr in arrays) + + # Create an empty 2D array filled with NaNs + result = np.full((len(arrays), max_length), np.nan) + + # Populate the 2D array with the values from the 1D arrays + for i, arr in enumerate(arrays): + result[i, :len(arr)] = arr + + return result + class LayerByIndex: def __init__(self, reps_path: str, load_into_memory: bool = True): self.reps_path = reps_path @@ -23,7 +48,8 @@ def __init__(self, reps_path: str, load_into_memory: bool = True): self.layers = None self.load_into_memory = load_into_memory self.in_memory_data = None - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = 'cuda' if torch.cuda.is_available() else \ + 'mps' if torch.backends.mps.is_available() else 'cpu' def __enter__(self): self.representations = h5py.File(self.reps_path, 'r') @@ -108,8 +134,6 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd desc='Comparing Representations at layer', total=len(representations_0), initial = 1): - # layer_0_name = layer_0.name.split('/')[-1] - # layer_1_name = layer_1.name.split('/')[-1] if layer_0_name != layer_1_name: raise ValueError(f'Layer mismatch: {layer_0_name} != {layer_1_name}') num = f"{int(layer_0_name.split('_')[-1]):03d}" @@ -119,8 +143,6 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd for batch_0, batch_1 in tqdm(zip(layer_0.values(), layer_1.values()), desc='Batch', total=len(layer_0), leave=False, initial = 1): - # batch_0 = torch.tensor(layer_0[batch_0][:], device=device) - # batch_1 = torch.tensor(layer_1[batch_1][:], device=device) # Calculate the metrics for each batch for metric in metrics: @@ -129,8 +151,7 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd layer_results = Layer(WeightInfo(name=layer_name)) # Aggregate over the batches and add to the layer results for metric in metrics: - layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) - # metric.clear() + layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) # (X) results.add_layer(layer_results, layer_name) @@ -149,9 +170,6 @@ def block(representations, block_size, metric_classes, device='cpu'): for batch_0, batch_1 in tqdm(zip(block_start.values(), block_end.values()), desc='Batch', total=len(block_start), leave=False, initial = 1): - # batch_0 = torch.tensor(block_start[batch_0][:]).to(device) - # batch_1 = torch.tensor(block_end[batch_1][:]).to(device) - for metric in metrics: metric.process_batch(batch_0, batch_1) @@ -159,8 +177,8 @@ def block(representations, block_size, metric_classes, device='cpu'): for metric in metrics: out[metric.__class__.__name__.lower()].append(metric.aggregate()) - for metric in out: - out[metric] = np.array(out[metric]) + for metric_name, metric in out.items(): + out[metric_name] = np.array([m.mean_std.mean for m in out[metric_name]]) return out def all_layers(representations_0, representations_1, metric_classes, results, device='cpu'): @@ -175,8 +193,6 @@ def all_layers(representations_0, representations_1, metric_classes, results, de for batch_0, batch_1 in tqdm(zip(layer_0.values(), layer_1.values()), desc='Batch', total=len(layer_0), leave=False, initial = 1): - # batch_0 = torch.tensor(layer_0[batch_0][:]).to(device) - # batch_1 = torch.tensor(layer_1[batch_1][:]).to(device) for metric in metrics: metric.process_batch(batch_0, batch_1) @@ -202,7 +218,7 @@ class Configuration: @classmethod def 
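A quick usage sketch of convert_to_2d_array as defined above: ragged per-block score lists become one NaN-padded grid that plots cleanly as a heatmap with missing cells.

    import numpy as np

    ragged = [np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5])]
    grid = convert_to_2d_array(ragged)
    assert grid.shape == (2, 3)
    assert np.isnan(grid[1, 2])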
from_dict(cls, config_dict: Dict): return cls( - representation_paths=list(config_dict['stored_representations'].iterdir()), + representation_paths=list([path for path in config_dict['representations_to_analyse'].iterdir() if path.suffix == '.h5']), metrics=config_dict['metrics'], analysis_type=ModelAnalysisType(config_dict['analysis_type']), comparison_type=LayerComparisonType(config_dict['comparison_type']), @@ -282,26 +298,29 @@ def run(self, config: Configuration): if enabled} metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]] - heatmaps = {} for representation_path in config.representation_paths: + heatmaps = {} block_results = Results() block_results.model_paths = [representation_path] if not representation_path.exists(): raise FileNotFoundError(f"Representation file {representation_path} not found") for metric in metrics: - heatmaps[metric().__class__.__name__.lower()] = np.array([]) + heatmaps[metric().__class__.__name__.lower()] = [] with LayerByIndex(representation_path, load_into_memory = check_memory(representation_path)) as representations: - for block_size in range(1, 9): + for block_size in range(1, 9): # (X) block_res = block(representations=representations, block_size=block_size, metric_classes=metrics, device=config.device) for metric in metrics: - heatmaps[metric().__class__.__name__.lower()] = np.append(heatmaps[metric().__class__.__name__.lower()], block_res[metric().__class__.__name__.lower()]) + heatmaps[metric().__class__.__name__.lower()].append(block_res[metric().__class__.__name__.lower()]) for metric in metrics: - block_results.across_layer_metrics[metric().__class__.__name__.lower()] = heatmaps[metric().__class__.__name__.lower()] # Definitely a simpler way to code this (X) - + block_results.across_layer_metrics[metric().__class__.__name__.lower()] = Metric( + heatmap = Heatmap( + data = convert_to_2d_array(heatmaps[metric().__class__.__name__.lower()]) # Definitely a simpler way to code this (X) + ) + ) out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{config.analysis_type}+{config.comparison_type}.json" block_results.save(out_path) @@ -320,7 +339,9 @@ def run(self, config: Configuration): (rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON) and not stop): if rep_0 != rep_1: stop = True - with LayerByIndex(rep_0) as representations_0, LayerByIndex(rep_1) as representations_1: + load_into_memory = check_memory(rep_0, rep_1) + with LayerByIndex(rep_0, load_into_memory) as representations_0, \ + LayerByIndex(rep_1, load_into_memory) as representations_1: comparison_results = all_layers(representations_0=representations_0, representations_1=representations_1, diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 675b2f58..0bbb6404 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -9,14 +9,14 @@ def run(config_yml: str = "config.yml"): mergekit_root = Path(__file__).parent.parent config = yaml.safe_load(open(mergekit_root / 'representations' / 'configs' / config_yml, 'r')) config['out_dir'] = mergekit_root / 'representations' / 'stored_results' - config['stored_representations'] = mergekit_root / 'representations' / 'stored_representations' + config['representations_to_analyse'] = mergekit_root / 'representations' / 'representations_to_analyse' config = Configuration.from_dict(config) experiment = 
ExperimentFactory.create(config.comparison_type.name.lower()) experiment.run(config) @click.command() -@click.option('--config_yml', default="config_i_single.yml", help='path to the configuration file.') +@click.option('--config_yml', default="config_i_block.yml", help='path to the configuration file.') def main(config_yml: str = "config.yml"): run(config_yml) diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py index 9119be2f..2e26114a 100644 --- a/representations/visualise_representation_results.py +++ b/representations/visualise_representation_results.py @@ -16,6 +16,6 @@ def main(input_dir): if __name__ == '__main__': mergekit_root = Path(__file__).parent.parent - input_dir = mergekit_root / 'representations' / 'stored_results' + input_dir = mergekit_root / 'representations' / 'results_to_visualise' main(input_dir) \ No newline at end of file From 891bbb460487d161e43be3114f5005f6ec431fa6 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Fri, 2 Aug 2024 15:50:30 +0100 Subject: [PATCH 58/64] visualisation fixes --- mergekit/metric_methods/aggregator_metrics.py | 2 +- mergekit/plot_tools/plot_tools.py | 87 ++++++++++--------- representations/experiment_setup.py | 6 +- 3 files changed, 49 insertions(+), 46 deletions(-) diff --git a/mergekit/metric_methods/aggregator_metrics.py b/mergekit/metric_methods/aggregator_metrics.py index 818b016b..25ef602d 100644 --- a/mergekit/metric_methods/aggregator_metrics.py +++ b/mergekit/metric_methods/aggregator_metrics.py @@ -245,7 +245,7 @@ def __init__(self, device: str = "cpu"): super().__init__(device=device) self.tsne = TSNE(n_components=2, random_state=42) self.batches = [] - self.max_batches = 5 + self.max_batches = 20 self.stop = False self.valid_for.update({ diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 811ef0df..ec9a649c 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -16,6 +16,27 @@ global_colours_list = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] global_shapes_list = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] +def extend_without_duplicates(l1, l2): + """ + Extend list l1 with elements from list l2, ensuring no duplicates. + Assumes the lists contain dictionaries of strings + + Parameters: + l1 (list): The original list to be extended. + l2 (list): The list with elements to add to l1. + + Returns: + list: The extended list without duplicates. + """ + # Convert dictionaries to tuples of sorted items for hashability + l1_set = set(tuple(sorted(item.items())) for item in l1) + for item in l2: + item_tuple = tuple(sorted(item.items())) + if item_tuple not in l1_set: + l1.append(item) + l1_set.add(item_tuple) + return l1 + class ResultsHandler: def __init__(self): self.intra_model_results: Dict[ModelReference, Results] = {} @@ -166,45 +187,23 @@ def get_traces(self, data:List, plot_type:str): # Can prob use type of data to d return traces - def _set_plot_attributes(self, ax, stat: str, ax_kwargs: List[str], **kwargs): - """ - Set the attributes of the plot. - - Args: - ax: The matplotlib Axes object. - stat (str): The name of the stat. - **kwargs: Additional keyword arguments for plot attributes. 
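The t-SNE aggregator above now collects up to 20 batches before projecting. A self-contained sketch of that bounded-collection pattern, with random tensors standing in for hidden-state batches and scikit-learn assumed to be available:

    import torch
    from sklearn.manifold import TSNE

    batches, max_batches = [], 20
    for _ in range(max_batches):
        batches.append(torch.randn(8, 64))   # stand-in for one batch of hidden states

    data = torch.cat(batches, dim=0).numpy()
    points_2d = TSNE(n_components=2, random_state=42).fit_transform(data)
    # points_2d has shape (160, 2), ready to draw as a scatter plot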
- """ - # Defaults - ax.set_ylabel(kwargs.get('ylabel', stat)) - ax.set_xticks(np.arange(len(self.layer_names))) - ax.set_xticklabels(self.layer_names, rotation=45) - ax.set_title(kwargs.get('title', f'{stat.replace("_", " ").title()}')) - - # Set additional attributes - for kwarg in ax_kwargs: - if kwarg in kwargs: - getattr(ax, f"set_{kwarg}")(kwargs[kwarg]) - def layer_plot_options(self, layer_name: str): metric_options = [] + avoid_duplicates = [] for plot_type in PlotType: if plot_type == PlotType.MEAN_STD: continue - metric_options.extend([ - {"label": f"{metric.title()} {plot_type.value}", "value": [metric, plot_type.value]} - for metric in self.inter_model_results.layers[layer_name].metrics_with_attribute(plot_type.value)] - ) for result in self.all_results: if layer_name in result.layers: - metric_options.extend([ - {"label": f"{metric.title()} {plot_type.value}", "value": [metric, plot_type.value]} - for metric in result.layers[layer_name].metrics_with_attribute(plot_type.value)] - ) - break # Assuming all intra-model results have the same metrics + metrics = result.layers[layer_name].metrics_with_attribute(plot_type.value) + for metric in metrics: + if (metric, plot_type) not in avoid_duplicates: + metric_options.append({f"label": f"{metric.title()} {plot_type.value}", + "value": [metric, plot_type.value] + }) + avoid_duplicates.append((metric, plot_type)) return metric_options - def create_app(results_handler): app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css']) @@ -256,9 +255,9 @@ def create_across_layers_section(results_handler): if getattr(result, 'across_layer_metrics', None): for metric_name, metric in result.across_layer_metrics.items(): for plot_type in ['histogram', 'heatmap', 'scatter_plot']: - if getattr(metric, plot_type, None): #(X) shouldn't need [0] - metric is being stored inside an array and shouldn't be! + if getattr(metric, plot_type, None): plot_sections.append(html.Div([ - html.H3(f'{plot_type+metric_name.replace("_", " ").title()} {plot_type.replace("_", " ").title()}', style={'textAlign': 'center'}), + html.H3(f'{metric_name.replace("_", " ").title()} {plot_type.replace("_", " ").title()}', style={'textAlign': 'center'}), dcc.Graph(id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}", style={'width': '50%', 'height': '50%', 'position': 'relative'}) ], className='container-fluid')) @@ -312,7 +311,7 @@ def display_layer_data(selected_metric, clickData): metric_name, plot_type = selected_metric - if plot_type.lower() in ["heatmap", "scatter_plot"]: + if plot_type.lower() in ["heatmap", "scatter_plot"]: #(x) xaxis_title = "Model 1" yaxis_title = "Model 0" elif plot_type.lower() == 'histogram': @@ -362,29 +361,35 @@ def update_line_plot(selected_metric): for metric_name, metric in result.across_layer_metrics.items(): for plot_type in ['histogram', 'heatmap', 'scatter_plot']: if getattr(metric, plot_type, None): #(X) shouldn't need [0] - metric is being stored inside an array and shouldn't be! 
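Because the per-plot callbacks here are registered inside nested loops, the loop variables are bound as default arguments (for example _id=id, plot_type=plot_type); without that, every callback would close over the final loop values. A minimal illustration of the underlying Python behaviour, independent of Dash:

    # Late binding: each closure sees the last value of x
    late = [lambda: x for x in range(3)]
    assert [f() for f in late] == [2, 2, 2]

    # Default arguments capture the value at definition time
    bound = [lambda x=x: x for x in range(3)]
    assert [f() for f in bound] == [0, 1, 2]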
- id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}" + id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}" # Improve naming scheme, this could get confused with comparison results @app.callback( Output(id, 'figure'), Input(id, 'id') ) - def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric): + def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric, mode_paths=result.model_paths): traces = results_handler.get_traces(data = [getattr(metric, plot_type)], plot_type = plot_type) return create_figure(traces=traces, title=f"{id} | {metric_name}", - xaxis_title="Temp title", - yaxis_title=metric_name, + xaxis_title="Layer Model 0", + yaxis_title=f"Layer Model {0 if len(mode_paths) == 1 else 1}", plot_type = plot_type ) -def create_figure(traces, title, xaxis_title, yaxis_title, plot_type): +def create_figure(traces, xaxis_title, yaxis_title, plot_type, title=None, subplot_titles=[None]): if plot_type in ["scatter_plot", "heatmap"]: num_plots = len(traces) - num_cols = 2 - num_rows = (num_plots + 1) // num_cols + + num_cols = 2 if num_plots > 1 else 1 + + num_rows = (num_plots + 1) // num_cols if num_plots > 1 else 1 - fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=[f"Plot {i+1}" for i in range(num_plots)]) + fig = make_subplots(rows=num_rows, + cols=num_cols, + subplot_titles=subplot_titles + if subplot_titles + else [f"Plot {i+1}" for i in range(num_plots)]) for i, trace in enumerate(traces): row = (i // num_cols) + 1 diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py index 7eece887..1829f379 100644 --- a/representations/experiment_setup.py +++ b/representations/experiment_setup.py @@ -136,8 +136,6 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd if layer_0_name != layer_1_name: raise ValueError(f'Layer mismatch: {layer_0_name} != {layer_1_name}') - num = f"{int(layer_0_name.split('_')[-1]):03d}" - layer_name = f"Layer {num}" metrics = [metric_class(device=device) for metric_class in metric_classes] @@ -148,12 +146,12 @@ def corresponding(representations_0: LayerByIndex, representations_1: LayerByInd for metric in metrics: metric.process_batch(batch_0, batch_1) - layer_results = Layer(WeightInfo(name=layer_name)) + layer_results = Layer(WeightInfo(name=layer_0_name)) # Aggregate over the batches and add to the layer results for metric in metrics: layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower()) # (X) - results.add_layer(layer_results, layer_name) + results.add_layer(layer_results, layer_0_name) return results From 04b64e86fd0d5b96d40ef28e3ccdf1fad491ee42 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Fri, 2 Aug 2024 18:11:18 +0100 Subject: [PATCH 59/64] address model path naming issue (but only for representations! 
Not model weights) --- mergekit/_data/models_and_datasets.json | 5 +++ mergekit/_data/models_and_datasets.py | 45 +++++++++++++++++++++ mergekit/metric_methods/base.py | 33 +++++++++++++-- mergekit/plot_tools/plot_tools.py | 49 +++++++---------------- representations/experiment_setup.py | 22 +++++----- representations/representation_metrics.py | 2 +- representations/store_representations.py | 34 +++++++++------- 7 files changed, 124 insertions(+), 66 deletions(-) create mode 100644 mergekit/_data/models_and_datasets.json create mode 100644 mergekit/_data/models_and_datasets.py diff --git a/mergekit/_data/models_and_datasets.json b/mergekit/_data/models_and_datasets.json new file mode 100644 index 00000000..63aa2cf5 --- /dev/null +++ b/mergekit/_data/models_and_datasets.json @@ -0,0 +1,5 @@ +{ + "models": [], + + "datasets": [] +} \ No newline at end of file diff --git a/mergekit/_data/models_and_datasets.py b/mergekit/_data/models_and_datasets.py new file mode 100644 index 00000000..6ebfda65 --- /dev/null +++ b/mergekit/_data/models_and_datasets.py @@ -0,0 +1,45 @@ +import json +from pathlib import Path + +def save_model_and_dataset(model_name, dataset_name): + models_and_datasets = presets().load() + if model_name not in models_and_datasets['model_names']: + models_and_datasets['model_names'].append(model_name) + if dataset_name not in models_and_datasets['dataset_names']: + models_and_datasets['dataset_names'].append(dataset_name) + presets().save(models_and_datasets['model_names'], models_and_datasets['dataset_names']) + +def model_and_dataset_to_index(model_name, dataset_name): + models_and_datasets = presets().load() + model_index = models_and_datasets['model_names'].index(model_name) if model_name in models_and_datasets['model_names'] else [] + dataset_index = models_and_datasets['dataset_names'].index(dataset_name) if dataset_name in models_and_datasets['dataset_names'] else [] + + return model_index, dataset_index + +def index_to_model_and_dataset(model_index, dataset_index): + models_and_datasets = presets().load() + model_name = models_and_datasets['model_names'][model_index] if len(models_and_datasets['model_names']) > model_index else [] + dataset_name = models_and_datasets['dataset_names'][dataset_index] if len(models_and_datasets['dataset_names']) > dataset_index else [] + return model_name, dataset_name + +class presets(): + def __init__(self): + self.FILE_PATH = Path(__file__).parent / 'models_and_datasets.json' + + def load(self): + """Load the lists from the JSON file.""" + if self.FILE_PATH.exists(): + with open(self.FILE_PATH, 'r') as file: + data = json.load(file) + return data + print(f"File {self.FILE_PATH} does not exist or is empty.") + return {} + + def save(self, model_names, dataset_names): + """Save the lists to the JSON file.""" + data = { + 'model_names': model_names, + 'dataset_names': dataset_names + } + with open(self.FILE_PATH, 'w') as file: + json.dump(data, file, indent=4) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 7a912158..5c429af4 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -122,25 +122,39 @@ def expand_to_fit(all_layer_names: List[str], values: List[float], subset_layer_ from typing import List, Tuple from mergekit.graph import Task +from mergekit._data.models_and_datasets import save_model_and_dataset, model_and_dataset_to_index, index_to_model_and_dataset class Results: # Class to store the statistics for each layer def __init__(self): self.layers: Dict[str, Layer] 
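The new models_and_datasets helpers act as a small append-only registry so that long model and dataset names can be replaced by short, stable indices in representation filenames. A usage sketch, assuming the JSON file carries the model_names/dataset_names keys that presets reads, and reusing the example model and dataset names from store_representations:

    from mergekit._data.models_and_datasets import (
        save_model_and_dataset, model_and_dataset_to_index, index_to_model_and_dataset)

    model, dataset = "BEE-spoke-data/smol_llama-220M-GQA", "arcee-ai/sec-data-mini"
    save_model_and_dataset(model, dataset)

    m_idx, d_idx = model_and_dataset_to_index(model, dataset)
    assert index_to_model_and_dataset(m_idx, d_idx) == (model, dataset)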
= {} self.across_layer_metrics: Dict[str, Metric] = {} - self.model_paths: Optional[List[str]] = None + # self.model_paths: Optional[List[str]] = None + self.representations_details: Optional[List[Tuple]] = [] # List of tuples of (model_name, dataset_name) def add_layer(self, layer: Layer, name: str): if name not in self.layers.keys(): self.layers[name] = layer - def load_metrics(self, metrics: List[Tuple[Task, Layer]], model_paths: Optional[List[str]] = None): + def load_metrics(self, metrics: List[Tuple[Task, Layer]], model_paths: Optional[List[str]] = None): # Maybe remove (#) self.model_paths = model_paths for task, metric in metrics: if metric is not None: self.add_layer(metric, name=task.weight_info.name) return self + def load_representations_details_from_path(self, representations_path: str): + representations_path = Path(representations_path) + if representations_path.exists() and representations_path.is_file(): + file_name = representations_path.name + assert file_name.endswith('.h5'), f"File {file_name} is not an HDF5 file." + assert len(file_name.split('_')) == 3, f"File {file_name} does not follow the naming convention." + + model_name, dataset_name = index_to_model_and_dataset(*file_name.split('_')[:2]) + assert model_and_dataset_to_index(model_name, dataset_name) != ([], []), f"Model and dataset {model_name, dataset_name} not found in presets." + + self.representations_details.append((model_name, dataset_name)) + def get_lineplot_data(self, metric_name: str): means, stds = [],[] @@ -217,8 +231,19 @@ def finalise(self): self.layer_names = list(self.layers.keys()) self.metric_names = list(set([metric for layer in self.layers.values() for metric in layer.metrics.keys()])) - def save(self, path: str): - path = Path(path) + def save(self, out_dir: str, suffix: Optional[str] = None): + out_dir = Path(out_dir) + + file_name = '' + for i, (model_name, dataset_name) in enumerate(self.representations_details): + m_idx, d_idx = model_and_dataset_to_index(model_name, dataset_name) + file_name += f"details_{i}__{m_idx}_{d_idx}" + + if suffix: + file_name += f"_{suffix}" + + path = out_dir / f"{file_name}.pkl" + if not path.suffix or path.suffix != '.pkl': path = path.with_suffix('.pkl') diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index ec9a649c..5cb919e6 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -16,27 +16,6 @@ global_colours_list = ['blue', 'red', 'green', 'purple', 'orange', 'pink'] global_shapes_list = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star'] -def extend_without_duplicates(l1, l2): - """ - Extend list l1 with elements from list l2, ensuring no duplicates. - Assumes the lists contain dictionaries of strings - - Parameters: - l1 (list): The original list to be extended. - l2 (list): The list with elements to add to l1. - - Returns: - list: The extended list without duplicates. 
- """ - # Convert dictionaries to tuples of sorted items for hashability - l1_set = set(tuple(sorted(item.items())) for item in l1) - for item in l2: - item_tuple = tuple(sorted(item.items())) - if item_tuple not in l1_set: - l1.append(item) - l1_set.add(item_tuple) - return l1 - class ResultsHandler: def __init__(self): self.intra_model_results: Dict[ModelReference, Results] = {} @@ -48,23 +27,22 @@ def __init__(self): 'scatter_plot': [] } - def load_results(self, results: Results): + def load_results(self, results: Results): # Generalise to handle both representations and model weights results.finalise() - if len(results.model_paths) == 2: + if len(results.representations_details) == 2: self.inter_model_results = results - elif len(results.model_paths) == 1: - # key = results.model_paths[0] + elif len(results.representations_details) == 1: key = len(self.intra_model_results) self.intra_model_results[key] = results else: - raise ValueError("Results should have either 1 or 2 model_paths") + raise ValueError(f"Results should have either 1 or 2 inputs, got {len(results.representations_details)}") for plot_type in self.available_layer_plots.keys(): if self.inter_model_results is not None: self.available_layer_plots[plot_type] += list(self.inter_model_results.available_plot_types(plot_type).keys()) if self.intra_model_results is not None: - for model_path, results in self.intra_model_results.items(): + for model_id, results in self.intra_model_results.items(): self.available_layer_plots[plot_type] += list(results.available_plot_types(plot_type).keys()) self.available_layer_plots[plot_type] = list(set(self.available_layer_plots[plot_type])) @@ -96,11 +74,11 @@ def plotly_line_plots(self, metric_name:str): else: unique_categories = list(self.intra_model_results.keys()) - for i, (model_path, results) in enumerate(self.intra_model_results.items()): + for i, (model_id, results) in enumerate(self.intra_model_results.items()): layer_names = results.layer_names means, stds = results.get_lineplot_data(metric_name) if means: - categorised_layers = [model_path]*len(layer_names) # Different category for each model, every layer in each model has the same category + categorised_layers = [model_id]*len(layer_names) # Different category for each model, every layer in each model has the same category shape = global_shapes_list[i%len(global_shapes_list)] traces.extend(self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories, shape)) @@ -258,7 +236,7 @@ def create_across_layers_section(results_handler): if getattr(metric, plot_type, None): plot_sections.append(html.Div([ html.H3(f'{metric_name.replace("_", " ").title()} {plot_type.replace("_", " ").title()}', style={'textAlign': 'center'}), - dcc.Graph(id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}", style={'width': '50%', 'height': '50%', 'position': 'relative'}) + dcc.Graph(id=f"{plot_type}-plot-{metric_name}-{''.join(item for tup in result.representations_details for item in tup)}", style={'width': '50%', 'height': '50%', 'position': 'relative'}) ], className='container-fluid')) return html.Div(plot_sections) @@ -361,19 +339,20 @@ def update_line_plot(selected_metric): for metric_name, metric in result.across_layer_metrics.items(): for plot_type in ['histogram', 'heatmap', 'scatter_plot']: if getattr(metric, plot_type, None): #(X) shouldn't need [0] - metric is being stored inside an array and shouldn't be! 
- id=f"{plot_type}-plot-{metric_name}-{str(result.model_paths[0].name).split('__')[-1].split('.')[0]}" # Improve naming scheme, this could get confused with comparison results + id=f"{plot_type}-plot-{metric_name}-{''.join(item for tup in result.representations_details for item in tup)}" # Improve naming scheme, this could get confused with comparison results @app.callback( Output(id, 'figure'), Input(id, 'id') ) - def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric, mode_paths=result.model_paths): + def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric, rep_details=result.representations_details): traces = results_handler.get_traces(data = [getattr(metric, plot_type)], plot_type = plot_type) - + xaxis_title = f"Model {rep_details[0][0]} {rep_details[0][1]}" + yaxis_title = f"Model {rep_details[1][0]} {rep_details[1][1]}" if len(rep_details) > 1 else xaxis_title return create_figure(traces=traces, title=f"{id} | {metric_name}", - xaxis_title="Layer Model 0", - yaxis_title=f"Layer Model {0 if len(mode_paths) == 1 else 1}", + xaxis_title=xaxis_title + yaxis_title=xaxis_title plot_type = plot_type ) diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py index 1829f379..177332c0 100644 --- a/representations/experiment_setup.py +++ b/representations/experiment_setup.py @@ -248,7 +248,7 @@ def run(self, config: Configuration): metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] for representation_path in config.representation_paths: individual_results = Results() - individual_results.model_paths = [representation_path] + individual_results.representations_details_from_path(representation_path) if not representation_path.exists(): raise FileNotFoundError(f"Representation file {representation_path} not found") @@ -259,8 +259,7 @@ def run(self, config: Configuration): results=individual_results, device=config.device) - out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{config.analysis_type}+{config.comparison_type}.json" - individual_results.save(out_path) + individual_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") class CorrespondingExperiment(Experiment): def run(self, config: Configuration): @@ -284,10 +283,10 @@ def run(self, config: Configuration): metric_classes=metrics, results=comparison_results, device=config.device) - comparison_results.model_paths = [rep_0, rep_1] if rep_0 != rep_1 else [rep_0] + comparison_results.load_representations_details_from_path(rep_0) + comparison_results.load_representations_details_from_path(rep_1) - out_path = config.out_dir / f"{str([str(rep).split('/')[-1].split('.')[0] for rep in [rep_0, rep_1]])}+{config.analysis_type}+{config.comparison_type}.json" - comparison_results.save(out_path) + comparison_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") class BlockExperiment(Experiment): def run(self, config: Configuration): @@ -299,7 +298,7 @@ def run(self, config: Configuration): for representation_path in config.representation_paths: heatmaps = {} block_results = Results() - block_results.model_paths = [representation_path] + block_results.representations_details_from_path(representation_path) if not representation_path.exists(): raise FileNotFoundError(f"Representation file {representation_path} not found") for metric in metrics: @@ -319,8 +318,7 @@ def run(self, config: Configuration): data = 
convert_to_2d_array(heatmaps[metric().__class__.__name__.lower()]) # Definitely a simpler way to code this (X) ) ) - out_path = config.out_dir / f"{str(representation_path).split('/')[-1].split('.')[0]}+{config.analysis_type}+{config.comparison_type}.json" - block_results.save(out_path) + block_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") class AllLayersExperiment(Experiment): def run(self, config: Configuration): @@ -346,10 +344,10 @@ def run(self, config: Configuration): metric_classes=metrics, results=comparison_results, device=config.device) - comparison_results.model_paths = [rep_0, rep_1] + comparison_results.load_representations_details_from_path(rep_0) + comparison_results.load_representations_details_from_path(rep_1) - out_path = config.out_dir / f"{str([str(rep).split('/')[-1].split('.')[0] for rep in config.representation_paths])}+{config.analysis_type}+{config.comparison_type}.json" - comparison_results.save(out_path) + comparison_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") class ExperimentFactory: diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py index 0bbb6404..83fd1a82 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -8,7 +8,7 @@ def run(config_yml: str = "config.yml"): mergekit_root = Path(__file__).parent.parent config = yaml.safe_load(open(mergekit_root / 'representations' / 'configs' / config_yml, 'r')) - config['out_dir'] = mergekit_root / 'representations' / 'stored_results' + config['out_dir'] = mergekit_root / 'representations' / 'results_out' config['representations_to_analyse'] = mergekit_root / 'representations' / 'representations_to_analyse' config = Configuration.from_dict(config) diff --git a/representations/store_representations.py b/representations/store_representations.py index 75ac099f..f9593f4e 100644 --- a/representations/store_representations.py +++ b/representations/store_representations.py @@ -21,12 +21,14 @@ from typing import List import random -def load_batch_from_hdf5(model_name, batch_idx): - with h5py.File('batches.h5', 'r') as h5file: - dataset_name = f'{model_name}/batch_{batch_idx}' - batch_data = h5file[dataset_name][:] - batch_tensor = torch.tensor(batch_data) - return batch_tensor +from mergekit._data.models_and_datasets import save_model_and_dataset, model_and_dataset_to_index + +# def load_batch_from_hdf5(dataset_name, batch_idx): +# with h5py.File('batches.h5', 'r') as h5file: +# dataset_name = f'{dataset_name}/batch_{batch_idx}' +# batch_data = h5file[dataset_name][:] +# batch_tensor = torch.tensor(batch_data) +# return batch_tensor def set_seed(seed): torch.manual_seed(seed) @@ -49,7 +51,7 @@ def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tens last_non_padded_hidden_states.append(torch.cat(batch_last_tokens, dim=0)) return last_non_padded_hidden_states -def store_representations(model_path, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): +def store_representations(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): device = "cuda" if torch.cuda.is_available() \ else "mps" if torch.backends.mps.is_available() \ @@ -59,11 +61,11 @@ def store_representations(model_path, output_dir, dataset_name, batch_size, max_ if dataset_size: dataset = dataset.select(range(dataset_size)) - model = 
AutoModelForCausalLM.from_pretrained(model_path, + model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", output_hidden_states=True) - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name) if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token @@ -73,8 +75,12 @@ def store_representations(model_path, output_dir, dataset_name, batch_size, max_ set_seed(42) dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True) - - output_name = f'{output_dir}/{model.name_or_path}_{dataset_name}_{dataset_size}.h5'.replace("/","_") + + save_model_and_dataset(model_name, dataset_name) + model_index, dataset_index = model_and_dataset_to_index(model_name, dataset_name) + + output_name = f'{output_dir}/{model_index}_{dataset_index}_id_{np.random.randint(1000)}.h5'.replace("/","_") + assert not os.path.exists(output_name), f'{output_name} already exists.' with h5py.File(output_name, 'w') as h5file: @@ -100,7 +106,7 @@ def store_representations(model_path, output_dir, dataset_name, batch_size, max_ does not match expected number of hidden layers." @click.command() -@click.option('--model_path', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') +@click.option('--model_name', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') @click.option('--output_dir', default="./representations/stored_representations", help='folder to store the result in.') @click.option('--dataset_name', default="arcee-ai/sec-data-mini", help='dataset to use.') @click.option('--batch_size', default=8, help='batch size.') @@ -108,7 +114,7 @@ def store_representations(model_path, output_dir, dataset_name, batch_size, max_ @click.option('--dataset_size', default=4000, help='size of the dataset.') @click.option('--dataset_column', default="text", help='column of the dataset to use.') @click.option('--dataset_subset', default="train", help='subset of the dataset to use.') -def main(model_path, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): - store_representations(model_path, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset) +def main(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): + store_representations(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset) if __name__ == "__main__": main() From 1b5a3c04b401ea2f0ba2a551d351c5315d4fe79c Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Mon, 5 Aug 2024 13:55:36 +0100 Subject: [PATCH 60/64] minor fix --- mergekit/_data/models_and_datasets.json | 4 ++-- mergekit/plot_tools/plot_tools.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mergekit/_data/models_and_datasets.json b/mergekit/_data/models_and_datasets.json index 63aa2cf5..51eb2272 100644 --- a/mergekit/_data/models_and_datasets.json +++ b/mergekit/_data/models_and_datasets.json @@ -1,5 +1,5 @@ { - "models": [], + "model_names": [], - "datasets": [] + "dataset_names": [] } \ No newline at end of file diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 5cb919e6..b78479f4 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -351,8 +351,8 @@ def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric, rep_de yaxis_title = f"Model {rep_details[1][0]} 
{rep_details[1][1]}" if len(rep_details) > 1 else xaxis_title return create_figure(traces=traces, title=f"{id} | {metric_name}", - xaxis_title=xaxis_title - yaxis_title=xaxis_title + xaxis_title=xaxis_title, + yaxis_title=xaxis_title, plot_type = plot_type ) From 6bd12e719037938d8abe0d4d2db3cc3f681e66e4 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Tue, 6 Aug 2024 14:26:30 +0100 Subject: [PATCH 61/64] More fixes, end-to-end tested and working --- mergekit/metric_methods/base.py | 14 +++-- representations/experiment_setup.py | 12 ++-- representations/representation_metrics.py | 9 +-- representations/store_representations.py | 71 +++++++++++------------ 4 files changed, 53 insertions(+), 53 deletions(-) diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 5c429af4..12935f7d 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -148,9 +148,10 @@ def load_representations_details_from_path(self, representations_path: str): if representations_path.exists() and representations_path.is_file(): file_name = representations_path.name assert file_name.endswith('.h5'), f"File {file_name} is not an HDF5 file." - assert len(file_name.split('_')) == 3, f"File {file_name} does not follow the naming convention." - - model_name, dataset_name = index_to_model_and_dataset(*file_name.split('_')[:2]) + assert len(file_name.split('_')) == 4, f"File {file_name} does not follow the naming convention: '(model_num)_(dataset_num)_id_(unique_id).h5'" + + model_idx, dataset_idx = int(file_name.split('_')[0]), int(file_name.split('_')[1]) + model_name, dataset_name = index_to_model_and_dataset(model_idx, dataset_idx) assert model_and_dataset_to_index(model_name, dataset_name) != ([], []), f"Model and dataset {model_name, dataset_name} not found in presets." 
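The assertions above pin down the representation filename convention, '{model_idx}_{dataset_idx}_id_{unique}.h5'. A hypothetical standalone parser that mirrors the same checks:

    from pathlib import Path

    def parse_representation_filename(path: Path):
        """Recover (model_idx, dataset_idx) from '{model_idx}_{dataset_idx}_id_{unique}.h5'."""
        parts = path.name.split('_')
        assert path.suffix == '.h5' and len(parts) == 4, f"unexpected name: {path.name}"
        return int(parts[0]), int(parts[1])

    assert parse_representation_filename(Path("3_1_id_a9f2.h5")) == (3, 1)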
self.representations_details.append((model_name, dataset_name)) @@ -233,11 +234,14 @@ def finalise(self): def save(self, out_dir: str, suffix: Optional[str] = None): out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) - file_name = '' + file_name = 'details_' for i, (model_name, dataset_name) in enumerate(self.representations_details): + if i > 0: + file_name += '_and_' m_idx, d_idx = model_and_dataset_to_index(model_name, dataset_name) - file_name += f"details_{i}__{m_idx}_{d_idx}" + file_name += f"model_{m_idx}_dataset_{d_idx}" if suffix: file_name += f"_{suffix}" diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py index 177332c0..d8108995 100644 --- a/representations/experiment_setup.py +++ b/representations/experiment_setup.py @@ -248,7 +248,7 @@ def run(self, config: Configuration): metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]] for representation_path in config.representation_paths: individual_results = Results() - individual_results.representations_details_from_path(representation_path) + individual_results.load_representations_details_from_path(representation_path) if not representation_path.exists(): raise FileNotFoundError(f"Representation file {representation_path} not found") @@ -259,7 +259,7 @@ def run(self, config: Configuration): results=individual_results, device=config.device) - individual_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") + individual_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}") class CorrespondingExperiment(Experiment): def run(self, config: Configuration): @@ -286,7 +286,7 @@ def run(self, config: Configuration): comparison_results.load_representations_details_from_path(rep_0) comparison_results.load_representations_details_from_path(rep_1) - comparison_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") + comparison_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}") class BlockExperiment(Experiment): def run(self, config: Configuration): @@ -298,7 +298,7 @@ def run(self, config: Configuration): for representation_path in config.representation_paths: heatmaps = {} block_results = Results() - block_results.representations_details_from_path(representation_path) + block_results.load_representations_details_from_path(representation_path) if not representation_path.exists(): raise FileNotFoundError(f"Representation file {representation_path} not found") for metric in metrics: @@ -318,7 +318,7 @@ def run(self, config: Configuration): data = convert_to_2d_array(heatmaps[metric().__class__.__name__.lower()]) # Definitely a simpler way to code this (X) ) ) - block_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") + block_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}") class AllLayersExperiment(Experiment): def run(self, config: Configuration): @@ -347,7 +347,7 @@ def run(self, config: Configuration): comparison_results.load_representations_details_from_path(rep_0) comparison_results.load_representations_details_from_path(rep_1) - comparison_results.save(config.out_dir, suffix=f"{config.analysis_type}+{config.comparison_type}") + comparison_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}") class ExperimentFactory: diff --git 
a/representations/representation_metrics.py b/representations/representation_metrics.py index 83fd1a82..10f6e2ad 100644 --- a/representations/representation_metrics.py +++ b/representations/representation_metrics.py @@ -5,9 +5,10 @@ from experiment_setup import Configuration, ExperimentFactory -def run(config_yml: str = "config.yml"): +def run(config_yml: str = "config"): mergekit_root = Path(__file__).parent.parent - config = yaml.safe_load(open(mergekit_root / 'representations' / 'configs' / config_yml, 'r')) + + config = yaml.safe_load(open((mergekit_root / 'representations' / 'configs' / config_yml).with_suffix('.yml'), 'r')) config['out_dir'] = mergekit_root / 'representations' / 'results_out' config['representations_to_analyse'] = mergekit_root / 'representations' / 'representations_to_analyse' config = Configuration.from_dict(config) @@ -16,8 +17,8 @@ def run(config_yml: str = "config.yml"): experiment.run(config) @click.command() -@click.option('--config_yml', default="config_i_block.yml", help='path to the configuration file.') -def main(config_yml: str = "config.yml"): +@click.option('--config_yml', default="config_i_block", help='path to the configuration file.') +def main(config_yml: str = "config"): run(config_yml) if __name__ == "__main__": diff --git a/representations/store_representations.py b/representations/store_representations.py index f9593f4e..5d5a9f81 100644 --- a/representations/store_representations.py +++ b/representations/store_representations.py @@ -1,34 +1,18 @@ -# WORK IN PROGRESS - import click import h5py - -import logging import numpy as np from tqdm import tqdm - import torch from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer import datasets import os - -logging.basicConfig(level=logging.INFO) - -# Set seed -torch.manual_seed(42) -np.random.seed(42) -from typing import List import random - +from pathlib import Path from mergekit._data.models_and_datasets import save_model_and_dataset, model_and_dataset_to_index - -# def load_batch_from_hdf5(dataset_name, batch_idx): -# with h5py.File('batches.h5', 'r') as h5file: -# dataset_name = f'{dataset_name}/batch_{batch_idx}' -# batch_data = h5file[dataset_name][:] -# batch_tensor = torch.tensor(batch_data) -# return batch_tensor +from typing import List +import gc +import uuid def set_seed(seed): torch.manual_seed(seed) @@ -39,7 +23,6 @@ def set_seed(seed): torch.backends.cudnn.benchmark = False def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tensor]: - """Get last non-padded tokens for each layer.""" last_non_padded_hidden_states = [] for layer in hidden_states: batch_size, _, _ = layer.size() @@ -52,37 +35,42 @@ def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tens return last_non_padded_hidden_states def store_representations(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): + # Generate the unique ID using UUID + unique_id = uuid.uuid4().hex[:4] + + #!important: Set seed for consistent batch order across runs + set_seed(42) + save_model_and_dataset(model_name, dataset_name) + model_index, dataset_index = model_and_dataset_to_index(model_name, dataset_name) + output_name = Path(output_dir) / f'{model_index}_{dataset_index}_id_{unique_id}.h5'.replace("/","_") + set_seed(42) + assert not output_name.exists(), f'{output_name} already exists.' 
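# Sketch of the idea behind get_last_non_padded_tokens above: for each layer, keep only the
# hidden state at each sequence's last attended (non-padded) position. Names and shapes are
# illustrative; the real function applies this per layer across all hidden states.
import torch

def last_non_padded(layer_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # layer_hidden: (batch, seq_len, hidden_dim); attention_mask: (batch, seq_len) of 0/1
    last_idx = attention_mask.sum(dim=1) - 1                 # index of last attended token
    batch_idx = torch.arange(layer_hidden.size(0), device=layer_hidden.device)
    return layer_hidden[batch_idx, last_idx]                 # (batch, hidden_dim)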
+ for reps_name in output_name.parent.iterdir(): + if f'{model_index}_{dataset_index}' in reps_name.name: + raise ValueError(f'Representations for model {model_index} and dataset {dataset_index} already exist in {output_name.parent}') + os.makedirs(output_name.parent, exist_ok=True) - device = "cuda" if torch.cuda.is_available() \ - else "mps" if torch.backends.mps.is_available() \ - else "cpu" dataset = datasets.load_dataset(dataset_name, split=dataset_subset) if dataset_size: dataset = dataset.select(range(dataset_size)) + device = "cuda" if torch.cuda.is_available() \ + else "mps" if torch.backends.mps.is_available() \ + else "cpu" model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", output_hidden_states=True) + model.eval() tokenizer = AutoTokenizer.from_pretrained(model_name) if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token - model.eval() - - set_seed(42) dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True) - save_model_and_dataset(model_name, dataset_name) - model_index, dataset_index = model_and_dataset_to_index(model_name, dataset_name) - - output_name = f'{output_dir}/{model_index}_{dataset_index}_id_{np.random.randint(1000)}.h5'.replace("/","_") - - assert not os.path.exists(output_name), f'{output_name} already exists.' - with h5py.File(output_name, 'w') as h5file: for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")): inputs = tokenizer(batch, return_tensors="pt", padding="longest", max_length=max_length, truncation=True).to(device) @@ -90,11 +78,11 @@ def store_representations(model_name, output_dir, dataset_name, batch_size, max_ outputs = model(**inputs) attention_mask = inputs["attention_mask"] hidden_states = outputs.hidden_states - last_non_padded_hidden_states = get_last_non_padded_tokens(hidden_states, attention_mask) - + # Remove the first element to account for the input layer not being considered a model hidden layer - # This adjustment is necessary for analyses focusing on the model's internal transformations - last_non_padded_hidden_states = last_non_padded_hidden_states[1:] + last_non_padded_hidden_states = get_last_non_padded_tokens(hidden_states, attention_mask)[1:] + + last_non_padded_hidden_states = last_non_padded_hidden_states for layer, hidden_state in enumerate(last_non_padded_hidden_states): layer_group = h5file.require_group(f'layer_{layer:03d}') file_name = f'batch_{batch_idx}.pt' @@ -104,6 +92,13 @@ def store_representations(model_name, output_dir, dataset_name, batch_size, max_ # Ensure that the length of last_non_padded_hidden_states matches the number of model hidden layers minus one assert len(last_non_padded_hidden_states) == model.config.num_hidden_layers, "Length of last_non_padded_hidden_states \ does not match expected number of hidden layers." 
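# Assuming each batch is written as a dataset named "batch_<idx>.pt" inside its "layer_XXX"
# group, as the loop above suggests, reading one layer back could look like this sketch.
# The file name is hypothetical; the numeric sort key is only needed because batch indices
# are not zero-padded.
import h5py
import torch

with h5py.File("0_1_id_a3f9.h5", "r") as h5file:             # hypothetical file name
    layer_group = h5file["layer_000"]
    batch_keys = sorted(layer_group.keys(), key=lambda k: int(k.split('_')[1].split('.')[0]))
    layer_representations = torch.cat(
        [torch.as_tensor(layer_group[k][()]) for k in batch_keys], dim=0
    )                                                         # (num_samples, hidden_dim)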
+
+    if torch.cuda.is_available():
+        # Clear GPU memory
+        del model
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+        gc.collect()
 
 @click.command()
 @click.option('--model_name', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.')

From 5ac4c788bfdf1ec8e4abf66a571081ae7bf0d93c Mon Sep 17 00:00:00 2001
From: ElliotStein
Date: Tue, 6 Aug 2024 14:56:42 +0100
Subject: [PATCH 62/64] refactor folder name

---
 representations/configs/config.yml | 4 ----
 representations/configs/config_comp_all.yml | 2 --
 representations/configs/config_comp_corresponding.yml | 2 --
 representations/configs/config_i_all.yml | 2 --
 representations/configs/config_i_block.yml | 10 ++++------
 representations/configs/config_i_single.yml | 2 --
 representations/store_representations.py | 2 +-
 7 files changed, 5 insertions(+), 19 deletions(-)

diff --git a/representations/configs/config.yml b/representations/configs/config.yml
index 2677b614..ec438485 100644
--- a/representations/configs/config.yml
+++ b/representations/configs/config.yml
@@ -1,7 +1,3 @@
-representation_paths:
-- "/Users/elliotstein/Documents/Arcee/mergekit/representations/stored_representations/Representations_Qwen_Qwen2-7B-Instruct_microsoft_orca-math-word-problems-200k_4000.h5"
-- "/Users/elliotstein/Documents/Arcee/mergekit/representations/stored_representations/Representations_arcee-ai_qwen2-7b-math-tess_microsoft_orca-math-word-problems-200k_4000.h5"
-
 metrics:
   cosine_similarity: true
   mse: false
diff --git a/representations/configs/config_comp_all.yml b/representations/configs/config_comp_all.yml
index d6958d76..0769d932 100644
--- a/representations/configs/config_comp_all.yml
+++ b/representations/configs/config_comp_all.yml
@@ -1,5 +1,3 @@
-stored_representations: "/workspace/mergekit/representations/stored_representations"
-
 metrics:
   cosine_similarity: true
   mse: true
diff --git a/representations/configs/config_comp_corresponding.yml b/representations/configs/config_comp_corresponding.yml
index 5eedc042..1be20197 100644
--- a/representations/configs/config_comp_corresponding.yml
+++ b/representations/configs/config_comp_corresponding.yml
@@ -1,5 +1,3 @@
-stored_representations: "/workspace/mergekit/representations/stored_representations"
-
 metrics:
   cosine_similarity: true
   mse: true
diff --git a/representations/configs/config_i_all.yml b/representations/configs/config_i_all.yml
index 5d894605..79398b54 100644
--- a/representations/configs/config_i_all.yml
+++ b/representations/configs/config_i_all.yml
@@ -1,5 +1,3 @@
-stored_representations: "/workspace/mergekit/representations/stored_representations"
-
 metrics:
   cosine_similarity: true
   mse: true
diff --git a/representations/configs/config_i_block.yml b/representations/configs/config_i_block.yml
index e86e57df..ec438485 100644
--- a/representations/configs/config_i_block.yml
+++ b/representations/configs/config_i_block.yml
@@ -1,11 +1,9 @@
-stored_representations: "/workspace/mergekit/representations/stored_representations"
-
 metrics:
   cosine_similarity: true
-  mse: true
-  linearity_score: true
-  cka: true
-  t-sne: true
+  mse: false
+  linearity_score: false
+  cka: false
+  t-sne: false
 
 analysis_type: "individual"
 comparison_type: "block"
\ No newline at end of file
diff --git a/representations/configs/config_i_single.yml b/representations/configs/config_i_single.yml
index 191975ff..695b1a38 100644
--- a/representations/configs/config_i_single.yml
+++ b/representations/configs/config_i_single.yml
@@ -1,5 +1,3 @@
-stored_representations: "/workspace/mergekit/representations/stored_representations"
-
 metrics:
   cosine_similarity: true
   mse: true
diff --git a/representations/store_representations.py b/representations/store_representations.py
index 5d5a9f81..f7c8621d 100644
--- a/representations/store_representations.py
+++ b/representations/store_representations.py
@@ -102,7 +102,7 @@ def store_representations(model_name, output_dir, dataset_name, batch_size, max_
 
 @click.command()
 @click.option('--model_name', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.')
-@click.option('--output_dir', default="./representations/stored_representations", help='folder to store the result in.')
+@click.option('--output_dir', default="./representations/representations_store", help='folder to store the result in.')
 @click.option('--dataset_name', default="arcee-ai/sec-data-mini", help='dataset to use.')
 @click.option('--batch_size', default=8, help='batch size.')
 @click.option('--max_length', default=1024, help='maximum length of the input.')

From 68a6a3047be81dc0628ef3bfd8f004b63fac8d58 Mon Sep 17 00:00:00 2001
From: ElliotStein
Date: Wed, 7 Aug 2024 14:42:15 +0100
Subject: [PATCH 63/64] Implement CKNNA and PCA visualisation

---
 mergekit/_data/models_and_datasets.json | 13 +-
 mergekit/metric_methods/aggregator_metrics.py | 188 ++++++++++++++----
 mergekit/plot_tools/plot_tools.py | 6 +-
 representations/configs/config.yml | 10 +-
 representations/configs/config_comp_all.yml | 2 +
 .../configs/config_comp_corresponding.yml | 2 +
 representations/configs/config_i_all.yml | 2 +
 representations/configs/config_i_block.yml | 10 +-
 representations/configs/config_i_single.yml | 2 +
 9 files changed, 180 insertions(+), 55 deletions(-)

diff --git a/mergekit/_data/models_and_datasets.json b/mergekit/_data/models_and_datasets.json
index 51eb2272..5c49773e 100644
--- a/mergekit/_data/models_and_datasets.json
+++ b/mergekit/_data/models_and_datasets.json
@@ -1,5 +1,12 @@
 {
-    "model_names": [],
-
-    "dataset_names": []
+    "model_names": [
+        "Qwen/Qwen2-7B-Instruct",
+        "arcee-ai/qwen2-7b-math-tess",
+        "arcee-ai/qwen2-mmamo-2",
+        "arcee-ai/Qwen2-Merger"
+    ],
+    "dataset_names": [
+        "microsoft/orca-math-word-problems-200k",
+        "MuskumPillerum/General-Knowledge"
+    ]
 }
\ No newline at end of file
diff --git a/mergekit/metric_methods/aggregator_metrics.py b/mergekit/metric_methods/aggregator_metrics.py
index 25ef602d..fbfdd2c1 100644
--- a/mergekit/metric_methods/aggregator_metrics.py
+++ b/mergekit/metric_methods/aggregator_metrics.py
@@ -111,7 +111,7 @@ def __init__(self, device: str = "cpu"):
         super().__init__(device=device)
 
         self.iterations = 0
-        self.max_iterations = 5
+        self.max_iterations = 25
         self.A = None
         self.optimiser = None
         self.initialised = False
@@ -167,50 +167,88 @@ def clear(self) -> None:
         pass
 
 class _CKA(object):
-    # Class from https://github.com/jayroxis/CKA-similarity/blob/main/CKA.py
     def __init__(self):
-        pass
-
+        self.kernel_functions = {
+            'inner_product': self.inner_product,
+            'rbf': self.rbf
+        }
+
+    def inner_product(self, X):
+        return X @ X.T
+
+    def rbf(self, X, sigma=None):
+        GX = torch.mm(X, X.t())
+        diag_GX = torch.diag(GX).unsqueeze(1)
+        KX = diag_GX - GX + (diag_GX - GX).t()
+
+        if sigma is None:
+            mdist = torch.median(KX[KX != 0])
+            sigma = torch.sqrt(mdist)
+
+        KX *= -0.5 / (sigma * sigma)
+        KX = torch.exp(KX)
+
+        return KX
+
     def centering(self, K):
         n = K.shape[0]
-        unit = np.ones([n, n])
-        I = np.eye(n)
+        unit = torch.ones([n, n]).to(K.device)
+        I = torch.eye(n).to(K.device)
         H = I - unit / n
-        return np.dot(np.dot(H, K), H)
+
return H @ K @ H + + def hsic(self, K_x, K_y): + """ + Hilbert-Schmidt Independence Criterion + Input: K_x, K_y: *Centered* Kernel matrices - def rbf(self, X, sigma=None): - GX = np.dot(X, X.T) - KX = np.diag(GX) - GX + (np.diag(GX) - GX).T - if sigma is None: - mdist = np.median(KX[KX != 0]) - sigma = math.sqrt(mdist) - KX *= -0.5 / (sigma * sigma) - KX = np.exp(KX) - return KX + Returns: HSIC(K_x, K_y) + """ + return torch.trace(K_x.T @ K_y) / ((K_x.shape[0]-1) ** 2) - def kernel_HSIC(self, X, Y, sigma): - return np.sum( - self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)) - ) + def cka(self, X, Y, kernel_function='inner_product'): + K_x = self.kernel_functions[kernel_function](X) + K_y = self.kernel_functions[kernel_function](Y) + + K_x = self.centering(K_x) + K_y = self.centering(K_y) - def linear_HSIC(self, X, Y): - L_X = X @ X.T - L_Y = Y @ Y.T - return np.sum(self.centering(L_X) * self.centering(L_Y)) + hsic_xy = self.hsic(K_x, K_y) + hsic_xx = self.hsic(K_x, K_x) + hsic_yy = self.hsic(K_y, K_y) - def linear_CKA(self, X, Y): - hsic = self.linear_HSIC(X, Y) - var1 = np.sqrt(self.linear_HSIC(X, X)) - var2 = np.sqrt(self.linear_HSIC(Y, Y)) + return hsic_xy / torch.sqrt(hsic_xx * hsic_yy) + + def align(self, X, Y, knn_x, knn_y): + """ + Input: X, Y: Centered Kernel matrices + """ + assert X.shape == Y.shape + num_rows, num_cols = X.shape + rows, cols = torch.meshgrid(torch.arange(num_rows), torch.arange(num_cols), indexing='ij') + + # Check if each element in the meshgrid is a mutual nearest neighbor + mutual_nn_mask = torch.isin(rows, knn_x.indices[cols]) & \ + torch.isin(cols, knn_y.indices[rows]) - return hsic / (var1 * var2) + trace_xy = torch.trace(X.T @ Y) + return mutual_nn_mask * trace_xy - def kernel_CKA(self, X, Y, sigma=None): - hsic = self.kernel_HSIC(X, Y, sigma) - var1 = np.sqrt(self.kernel_HSIC(X, X, sigma)) - var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma)) + def cknna(self, X, Y, kernel_function='inner_product', k=5): + K_x = self.kernel_functions[kernel_function](X) + K_y = self.kernel_functions[kernel_function](Y) - return hsic / (var1 * var2) + K_x = self.centering(K_x) + K_y = self.centering(K_y) + + k_nearest_neighbors_x = torch.topk(K_x, k=k, dim=1, largest=True, sorted=False) + k_nearest_neighbors_y = torch.topk(K_y, k=k, dim=1, largest=True, sorted=False) + + align_xy = self.align(K_x, K_y, k_nearest_neighbors_x, k_nearest_neighbors_y) + align_xx = self.align(K_x, K_x, k_nearest_neighbors_x, k_nearest_neighbors_x) + align_yy = self.align(K_y, K_y, k_nearest_neighbors_y, k_nearest_neighbors_y) + + return align_xy / torch.sqrt(align_xx * align_yy) class CKA(MetricAggregator): def __init__(self, device: str = "cpu"): @@ -219,7 +257,37 @@ def __init__(self, device: str = "cpu"): self.batches_a = [] self.batches_b = [] self.stop = False - self.max_batches = 10 + self.max_batches = 20 + + self.valid_for.update({ + LayerComparisonType.BLOCK.value: True, + LayerComparisonType.CORRESPONDING.value: True, + LayerComparisonType.ALL.value: True + }) + + def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: + if not self.stop: + self.batches_a.append(batch_a) + self.batches_b.append(batch_b) + + if len(self.batches_a) >= self.max_batches: + self.stop = True + + def aggregate(self) -> Metric: + result = self.cka.cka(torch.concat(self.batches_a), + torch.concat(self.batches_b)) + + self.__init__(self.device) # Reset ready for next layer + return Metric(mean_std=MeanStd(mean=result)) + +class CNNKA(MetricAggregator): + def 
__init__(self, device: str = "cpu"): + super().__init__(device=device) + self.cka = _CKA() + self.batches_a = [] + self.batches_b = [] + self.stop = False + self.max_batches = 20 self.valid_for.update({ LayerComparisonType.BLOCK.value: True, @@ -229,16 +297,18 @@ def __init__(self, device: str = "cpu"): def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None: if not self.stop: - self.batches_a.append(batch_a.cpu().numpy()) - self.batches_b.append(batch_b.cpu().numpy()) + self.batches_a.append(batch_a) + self.batches_b.append(batch_b) if len(self.batches_a) >= self.max_batches: self.stop = True def aggregate(self) -> Metric: - self.result = self.cka.linear_CKA(np.concatenate(self.batches_a), - np.concatenate(self.batches_b)) - return Metric(mean_std=MeanStd(mean=self.result)) + result = self.cka.cknna(torch.concat(self.batches_a), + torch.concat(self.batches_b)) + + self.__init__(self.device) # Reset ready for next layer + return Metric(mean_std=MeanStd(mean=result)) class t_SNE(MetricAggregator): def __init__(self, device: str = "cpu"): @@ -271,11 +341,47 @@ def aggregate(self) -> Metric: ) self.__init__(self.device) # Reset ready for next layer return metric + +class PCA_Projection(MetricAggregator): + def __init__(self, device: str = "cpu"): + super().__init__(device=device) + self.batches = [] + self.max_batches = 20 + self.stop = False + + self.valid_for.update({ + LayerComparisonType.SINGLE.value: True, + }) + + def process_batch(self, batch: torch.Tensor) -> None: + if not self.stop: + self.batches.append(batch.cpu().numpy()) + + if len(self.batches) >= self.max_batches: + self.stop = True + + def aggregate(self) -> Metric: + data = torch.cat(self.batches, dim=0) + mean = torch.mean(data, dim=0) + data -= mean + U, S, V = torch.pca_lowrank(data, q=2) + result = torch.matmul(data, V[:, :2]) + result = result.cpu().numpy() + metric = Metric( + scatter_plot=ScatterPlot( + x=result[:, 0], + y=result[:, 1], + ) + ) + self.__init__(self.device) # Reset ready for next layer + return metric METRICS_TABLE = { 'cosine_similarity': Cosine_Similarity, 'mse': MSE, 'linearity_score': Linearity_Score, 'cka': CKA, - 't-sne': t_SNE + 'cknna': CNNKA, + 't-sne': t_SNE, + 'pca_projection': PCA_Projection } \ No newline at end of file diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index b78479f4..6b93e37b 100644 --- a/mergekit/plot_tools/plot_tools.py +++ b/mergekit/plot_tools/plot_tools.py @@ -136,7 +136,7 @@ def get_traces(self, data:List, plot_type:str): # Can prob use type of data to d if plot_type == PlotType.HEATMAP.value: traces = [go.Heatmap( z=d.data, - colorscale='RdBu' + colorscale='Bluered', ) for d in data] elif plot_type == PlotType.SCATTER_PLOT.value: traces = [go.Scatter( @@ -338,8 +338,8 @@ def update_line_plot(selected_metric): if getattr(result, 'across_layer_metrics', None): for metric_name, metric in result.across_layer_metrics.items(): for plot_type in ['histogram', 'heatmap', 'scatter_plot']: - if getattr(metric, plot_type, None): #(X) shouldn't need [0] - metric is being stored inside an array and shouldn't be! 
- id=f"{plot_type}-plot-{metric_name}-{''.join(item for tup in result.representations_details for item in tup)}" # Improve naming scheme, this could get confused with comparison results + if getattr(metric, plot_type, None): + id=f"{plot_type}-plot-{metric_name}-{''.join(item for tup in result.representations_details for item in tup)}" @app.callback( Output(id, 'figure'), diff --git a/representations/configs/config.yml b/representations/configs/config.yml index ec438485..e049ef62 100644 --- a/representations/configs/config.yml +++ b/representations/configs/config.yml @@ -1,9 +1,11 @@ metrics: cosine_similarity: true - mse: false - linearity_score: false - cka: false - t-sne: false + mse: true + linearity_score: true + cka: true + cknna: true + t-sne: true + pca_projection: true analysis_type: "individual" comparison_type: "block" \ No newline at end of file diff --git a/representations/configs/config_comp_all.yml b/representations/configs/config_comp_all.yml index 0769d932..16b04e8b 100644 --- a/representations/configs/config_comp_all.yml +++ b/representations/configs/config_comp_all.yml @@ -3,7 +3,9 @@ metrics: mse: true linearity_score: true cka: true + cknna: true t-sne: true + pca_projection: true analysis_type: "comparison" comparison_type: "all" \ No newline at end of file diff --git a/representations/configs/config_comp_corresponding.yml b/representations/configs/config_comp_corresponding.yml index 1be20197..0cf8a5d9 100644 --- a/representations/configs/config_comp_corresponding.yml +++ b/representations/configs/config_comp_corresponding.yml @@ -3,7 +3,9 @@ metrics: mse: true linearity_score: true cka: true + cknna: true t-sne: true + pca_projection: true analysis_type: "comparison" comparison_type: "corresponding" \ No newline at end of file diff --git a/representations/configs/config_i_all.yml b/representations/configs/config_i_all.yml index 79398b54..e2764960 100644 --- a/representations/configs/config_i_all.yml +++ b/representations/configs/config_i_all.yml @@ -3,7 +3,9 @@ metrics: mse: true linearity_score: true cka: true + cknna: true t-sne: true + pca_projection: true analysis_type: "individual" comparison_type: "all" \ No newline at end of file diff --git a/representations/configs/config_i_block.yml b/representations/configs/config_i_block.yml index ec438485..e049ef62 100644 --- a/representations/configs/config_i_block.yml +++ b/representations/configs/config_i_block.yml @@ -1,9 +1,11 @@ metrics: cosine_similarity: true - mse: false - linearity_score: false - cka: false - t-sne: false + mse: true + linearity_score: true + cka: true + cknna: true + t-sne: true + pca_projection: true analysis_type: "individual" comparison_type: "block" \ No newline at end of file diff --git a/representations/configs/config_i_single.yml b/representations/configs/config_i_single.yml index 695b1a38..fe653e8d 100644 --- a/representations/configs/config_i_single.yml +++ b/representations/configs/config_i_single.yml @@ -3,7 +3,9 @@ metrics: mse: true linearity_score: true cka: true + cknna: true t-sne: true + pca_projection: true analysis_type: "individual" comparison_type: "single" \ No newline at end of file From 20dd0f670c48303e3acac9b3376f568695d88558 Mon Sep 17 00:00:00 2001 From: ElliotStein Date: Wed, 7 Aug 2024 17:36:17 +0100 Subject: [PATCH 64/64] fixes to CKNNA and tidy up --- mergekit/metric_methods/aggregator_metrics.py | 27 +++++++++---------- mergekit/metric_methods/base.py | 2 +- mergekit/plot_tools/plot_tools.py | 3 +++ representations/experiment_setup.py | 7 ++--- 4 files changed, 
21 insertions(+), 18 deletions(-) diff --git a/mergekit/metric_methods/aggregator_metrics.py b/mergekit/metric_methods/aggregator_metrics.py index fbfdd2c1..fb317406 100644 --- a/mergekit/metric_methods/aggregator_metrics.py +++ b/mergekit/metric_methods/aggregator_metrics.py @@ -39,8 +39,8 @@ def process_batch(self, batch_a: torch.Tensor, batch_b: Optional[torch.Tensor]) def aggregate(self) -> Metric: raise NotImplementedError - def clear(self) -> None: - raise NotImplementedError + # def clear(self) -> None: + # raise NotImplementedError class Cosine_Similarity(MetricAggregator): def __init__(self, device: str = "cpu"): @@ -73,8 +73,8 @@ def aggregate(self) -> Metric: mean_std=mean_std ) - def clear(self) -> None: - self.cosine_similarities = torch.tensor([], device=self.device) + # def clear(self) -> None: + # self.cosine_similarities = torch.tensor([], device=self.device) class MSE(MetricAggregator): def __init__(self, device: str = "cpu"): @@ -111,7 +111,7 @@ def __init__(self, device: str = "cpu"): super().__init__(device=device) self.iterations = 0 - self.max_iterations = 25 + self.max_iterations = 50 self.A = None self.optimiser = None self.initialised = False @@ -163,8 +163,8 @@ def aggregate(self) -> Metric: self.__init__() return Metric(mean_std=MeanStd(mean=linearity_score)) - def clear(self) -> None: - pass + # def clear(self) -> None: + # pass class _CKA(object): def __init__(self): @@ -228,11 +228,10 @@ def align(self, X, Y, knn_x, knn_y): rows, cols = torch.meshgrid(torch.arange(num_rows), torch.arange(num_cols), indexing='ij') # Check if each element in the meshgrid is a mutual nearest neighbor - mutual_nn_mask = torch.isin(rows, knn_x.indices[cols]) & \ - torch.isin(cols, knn_y.indices[rows]) + mutual_nn_mask = torch.isin(rows.to(knn_x[0].device), knn_x.indices[cols]) & \ + torch.isin(cols.to(knn_x[0].device), knn_y.indices[rows]) - trace_xy = torch.trace(X.T @ Y) - return mutual_nn_mask * trace_xy + return torch.sum(mutual_nn_mask * X * Y) def cknna(self, X, Y, kernel_function='inner_product', k=5): K_x = self.kernel_functions[kernel_function](X) @@ -278,7 +277,7 @@ def aggregate(self) -> Metric: torch.concat(self.batches_b)) self.__init__(self.device) # Reset ready for next layer - return Metric(mean_std=MeanStd(mean=result)) + return Metric(mean_std=MeanStd(mean=result.cpu().item())) class CNNKA(MetricAggregator): def __init__(self, device: str = "cpu"): @@ -308,7 +307,7 @@ def aggregate(self) -> Metric: torch.concat(self.batches_b)) self.__init__(self.device) # Reset ready for next layer - return Metric(mean_std=MeanStd(mean=result)) + return Metric(mean_std=MeanStd(mean=result.cpu().item())) class t_SNE(MetricAggregator): def __init__(self, device: str = "cpu"): @@ -355,7 +354,7 @@ def __init__(self, device: str = "cpu"): def process_batch(self, batch: torch.Tensor) -> None: if not self.stop: - self.batches.append(batch.cpu().numpy()) + self.batches.append(batch) if len(self.batches) >= self.max_batches: self.stop = True diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py index 12935f7d..cdc655d8 100644 --- a/mergekit/metric_methods/base.py +++ b/mergekit/metric_methods/base.py @@ -58,7 +58,7 @@ class MeanStd: @dataclass class Heatmap: data: torch.Tensor - update_layout_options: Optional[Dict] = None + plot_details: Optional[Dict] = None @dataclass class Histogram: diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py index 6b93e37b..6fb3d204 100644 --- a/mergekit/plot_tools/plot_tools.py +++ 
b/mergekit/plot_tools/plot_tools.py @@ -349,6 +349,9 @@ def update_across_layers_plot(_id=id, plot_type=plot_type, metric=metric, rep_de traces = results_handler.get_traces(data = [getattr(metric, plot_type)], plot_type = plot_type) xaxis_title = f"Model {rep_details[0][0]} {rep_details[0][1]}" yaxis_title = f"Model {rep_details[1][0]} {rep_details[1][1]}" if len(rep_details) > 1 else xaxis_title + if hasattr(metric, 'plot_details'): + xaxis_title = metric.plot_details['xaxis_title'] if 'xaxis_title' in metric.plot_details else xaxis_title + yaxis_title = metric.plot_details['yaxis_title'] if 'yaxis_title' in metric.plot_details else yaxis_title return create_figure(traces=traces, title=f"{id} | {metric_name}", xaxis_title=xaxis_title, diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py index d8108995..ca3a61f4 100644 --- a/representations/experiment_setup.py +++ b/representations/experiment_setup.py @@ -278,13 +278,13 @@ def run(self, config: Configuration): stop = True with LayerByIndex(rep_0) as representations_0, LayerByIndex(rep_1) as representations_1: + comparison_results.load_representations_details_from_path(rep_0) + comparison_results.load_representations_details_from_path(rep_1) comparison_results = corresponding(representations_0=representations_0, representations_1=representations_1, metric_classes=metrics, results=comparison_results, device=config.device) - comparison_results.load_representations_details_from_path(rep_0) - comparison_results.load_representations_details_from_path(rep_1) comparison_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}") @@ -315,7 +315,8 @@ def run(self, config: Configuration): for metric in metrics: block_results.across_layer_metrics[metric().__class__.__name__.lower()] = Metric( heatmap = Heatmap( - data = convert_to_2d_array(heatmaps[metric().__class__.__name__.lower()]) # Definitely a simpler way to code this (X) + data = convert_to_2d_array(heatmaps[metric().__class__.__name__.lower()]), # Definitely a simpler way to code this (X) + plot_details={'title': f'{metric().__class__.__name__} across N-blocks', 'xlabel': 'Layer', 'ylabel': 'Block Size'} ) ) block_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}")
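For reference, the inner-product (linear) CKA that these patches build on reduces to a few lines. The sketch below mirrors the centering/HSIC formulation used in _CKA above and is meant as an illustration under those definitions, not a drop-in replacement for the class.

import torch

def centering(K: torch.Tensor) -> torch.Tensor:
    n = K.shape[0]
    H = torch.eye(n, device=K.device) - torch.ones(n, n, device=K.device) / n
    return H @ K @ H

def hsic(K_x: torch.Tensor, K_y: torch.Tensor) -> torch.Tensor:
    # Hilbert-Schmidt Independence Criterion on centered kernel matrices
    return torch.trace(K_x.T @ K_y) / ((K_x.shape[0] - 1) ** 2)

def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # X, Y: (num_samples, features) activations from two models, same sample order
    K_x, K_y = centering(X @ X.T), centering(Y @ Y.T)
    return hsic(K_x, K_y) / torch.sqrt(hsic(K_x, K_x) * hsic(K_y, K_y))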