Weights Metrics #340

Open — wants to merge 67 commits into base: main
Commits (67) — changes shown from all commits
c2b0f06
introduced metric_methods, closely following implementation of merge_…
May 30, 2024
64e54bf
measure.py closely follows metric.py to apply chosen metrics
May 30, 2024
7e37d69
plot tools include class to handle output of run_metrics, and a class…
May 30, 2024
2effd41
only minor changes required to existing mergekit code
May 30, 2024
ed59d9b
only minor changes required to existing mergekit
May 30, 2024
564d45c
Implemented interactive dashboard for metrics visualisation
Jun 3, 2024
e530568
Remove single-model stats for now. Bring MSE into all_metrics
Jun 3, 2024
207e874
Introduce attention weights and restructure dashboard
Jun 7, 2024
b30175e
refine implementation of attention metrics, add line plots to dashboard
Jun 10, 2024
dfc0603
More restructuring, more seamless integration of attention and mlp la…
Jun 11, 2024
f88f904
vectorise heatmap computation
Jun 12, 2024
f89df57
Address issue with lexicographical sort by adding leading zeros to la…
Jun 14, 2024
74c5d33
remove unnecessary import from last commit
Jun 14, 2024
7c209d2
add validation check to ensure MergeConfiguration method is either Me…
Jun 17, 2024
db65f83
moved run_metrics and use click to enable commandline control over ar…
Jun 17, 2024
3fe66a8
rename example config
Jun 17, 2024
62692d1
correct case for gqa_group name
Jun 17, 2024
2a0c520
replace measure with merge + early out
Jun 17, 2024
f71e360
guard against divide by zero
Jun 17, 2024
7e14266
restore plan_to_disk functionality for merging. Move metrics planning…
Jun 17, 2024
0f3430f
add optional interactive plot packages
Jun 17, 2024
8d68e39
minor cleanup
Jun 17, 2024
59e23fe
Merge remote-tracking branch 'upstream/main'
Jun 17, 2024
89ecbf5
Merge branch 'main' into weights_metrics
Jun 17, 2024
404e395
Add GQA info to (llama) architecture and refactor
Jun 19, 2024
8e3c861
Pass GQA info from architecture json all the way to attn metrics. Gen…
Jun 19, 2024
a0e8c27
re-organised and simplified dashboard view
Jun 19, 2024
7e2b552
colour-categorise lineplot points by layertime
Jun 19, 2024
69c3b15
restructure and refactor results and plotting
Jun 24, 2024
f36fd70
restructure metrics storage, remove graph from plot, remove redundanc…
Jun 26, 2024
50a5716
Add intra-layer metrics, completed implementation of changes from pre…
Jun 26, 2024
73dd3ae
restructure metrics for modularity
ElliotStein Jul 4, 2024
ddbb475
add load and save functions to base
ElliotStein Jul 4, 2024
36d28b0
Internal Representaions analysis - first commit
ElliotStein Jul 4, 2024
fa6d098
refactor variable name for consistency
ElliotStein Jul 8, 2024
4949458
Plot heatmaps stored in results.others
ElliotStein Jul 8, 2024
bd62ed9
generalised results handler to load from metrics list or ready-made R…
ElliotStein Jul 8, 2024
9c87514
abstracted and add skip block analysis
ElliotStein Jul 9, 2024
4ce67b0
refactor and restructure
ElliotStein Jul 9, 2024
88ed18e
clean up imports and remove hard coding
ElliotStein Jul 9, 2024
1041298
tidy up tqdm
ElliotStein Jul 9, 2024
01d5a2b
improve robustness of load and save using pathlib
ElliotStein Jul 9, 2024
f64a71e
reintroduced heatmap functionality
ElliotStein Jul 9, 2024
8b05c2a
allow for plot keyworks to be passed into Heatmap object
ElliotStein Jul 9, 2024
b8f7e32
example config
ElliotStein Jul 9, 2024
c4df572
improve implementation consistency
ElliotStein Jul 9, 2024
5e051a8
experimental linearity score metric
ElliotStein Jul 16, 2024
bb8fce2
add matplotlib to optional dependencies
ElliotStein Jul 16, 2024
9c92efd
Merge remote-tracking branch 'upstream/main' into weights_metrics
ElliotStein Jul 16, 2024
e1c1ecd
remove quantisation and update environment reqs
ElliotStein Jul 16, 2024
9db9d3b
tidy up and add missing dependency
ElliotStein Jul 22, 2024
af2bf1a
Major restructuring of Results, Results handling, metrics
ElliotStein Jul 25, 2024
6045980
address alphanumeric representation layer name ordering issue
ElliotStein Jul 30, 2024
cb5e31a
MAJOR RESTRUCTURE of results, results handler and representation metrics
ElliotStein Jul 30, 2024
dfdcb00
tidy up
ElliotStein Jul 30, 2024
7bc8001
further restructuring
ElliotStein Jul 31, 2024
a5c5811
further refinements to representations experiment
ElliotStein Aug 1, 2024
3d3fd33
added necessary imports and modules
ElliotStein Aug 1, 2024
2dfa4f9
minor restructuring and refactoring
ElliotStein Aug 1, 2024
60726d4
Bug fixes
ElliotStein Aug 1, 2024
891bbb4
visualisation fixes
ElliotStein Aug 2, 2024
04b64e8
address model path naming issue (but only for representations! Not mo…
ElliotStein Aug 2, 2024
1b5a3c0
minor fix
ElliotStein Aug 5, 2024
6bd12e7
More fixes, end-to-end tested and working
ElliotStein Aug 6, 2024
5ac4c78
refactor folder name
ElliotStein Aug 6, 2024
68a6a30
Implement CKNNA and PCA visualisation
ElliotStein Aug 7, 2024
20dd0f6
fixes to CKNNA and tidy up
ElliotStein Aug 7, 2024
12 changes: 12 additions & 0 deletions .gitignore
@@ -1,3 +1,15 @@
*.pkl
*.h5
offload_folder/

# Environment

mergekit/bin/
mergekit/share/
mergekit/etc/
merged/
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
8 changes: 8 additions & 0 deletions examples/metrics-llama-1v2.yml
@@ -0,0 +1,8 @@
models:
  - model: huggyllama/llama-7b
  - model: TheBloke/Llama-2-7B-fp16

metric_method: all
parameters:
  intra_model_metrics: true
dtype: float32
6 changes: 6 additions & 0 deletions examples/metrics-llama-codev2.yml
@@ -0,0 +1,6 @@
models:
  - model: meta-llama/CodeLlama-7b-Python-hf
  - model: meta-llama/Llama-2-7b-hf

metric_method: all
dtype: float32
9 changes: 9 additions & 0 deletions examples/metrics-small.yml
@@ -0,0 +1,9 @@
models:
  - model: BEE-spoke-data/smol_llama-220M-GQA
  - model: BEE-spoke-data/smol_llama-220M-openhermes

metric_method: all
parameters:
  intra_model_metrics: true
  inter_model_metrics: true
dtype: float32
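
For illustration only: a metric config like the ones above can be driven through the library API touched later in this diff. This sketch is not part of the PR and assumes mergekit's existing run_merge and MergeOptions entry points (the PR also adds a click-based run_metrics CLI, which is not shown in this excerpt):

import yaml

from mergekit.config import MergeConfiguration
from mergekit.merge import run_merge
from mergekit.options import MergeOptions

with open("examples/metrics-small.yml", "r", encoding="utf-8") as fp:
    config = MergeConfiguration.model_validate(yaml.safe_load(fp))

# With metric_method set, run_merge collects and returns (task, value) pairs
# instead of writing a merged model (see the mergekit/merge.py change below).
metrics = run_merge(config, out_path="./metrics-out", options=MergeOptions(cuda=False))
for task, value in metrics:
    print(task, type(value))
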
10 changes: 8 additions & 2 deletions mergekit/_data/architectures/llama.json
@@ -21,21 +21,27 @@
      {
        "name": "model.layers.${layer_index}.self_attn.q_proj.weight",
        "input_space": "h_${layer_index}",
        "num_attention_heads": "${num_attention_heads}",
        "output_space": "attn_qk_${layer_index}"
      },
      {
        "name": "model.layers.${layer_index}.self_attn.k_proj.weight",
        "input_space": "h_${layer_index}",
        "output_space": "attn_qk_${layer_index}"
        "output_space": "attn_qk_${layer_index}",
        "num_attention_heads": "${num_attention_heads}",
        "num_key_value_heads": "${num_key_value_heads}"
      },
      {
        "name": "model.layers.${layer_index}.self_attn.v_proj.weight",
        "input_space": "h_${layer_index}",
        "output_space": "attn_v_${layer_index}"
        "output_space": "attn_v_${layer_index}",
        "num_attention_heads": "${num_attention_heads}",
        "num_key_value_heads": "${num_key_value_heads}"
      },
      {
        "name": "model.layers.${layer_index}.self_attn.o_proj.weight",
        "input_space": "attn_v_${layer_index}",
        "num_attention_heads": "${num_attention_heads}",
        "output_space": "post_attn_${layer_index}"
      },
      {
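
The ${num_attention_heads} and ${num_key_value_heads} placeholders above are filled in from the model's Hugging Face config by the template-substitution changes to mergekit/architecture.py further down. As a rough illustration (head counts assumed, not taken from this diff), the resolved k_proj entry for a GQA model with 32 attention heads and 8 key-value heads carries string values, since string.Template substitutes into string fields and WeightInfo accepts Union[int, str, None]:

from mergekit.architecture import WeightInfo

# Illustrative only: what the substituted k_proj entry would look like for layer 0.
w = WeightInfo(
    name="model.layers.0.self_attn.k_proj.weight",
    input_space="h_0",
    output_space="attn_qk_0",
    num_attention_heads="32",
    num_key_value_heads="8",
)
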
12 changes: 12 additions & 0 deletions mergekit/_data/models_and_datasets.json
@@ -0,0 +1,12 @@
{
    "model_names": [
        "Qwen/Qwen2-7B-Instruct",
        "arcee-ai/qwen2-7b-math-tess",
        "arcee-ai/qwen2-mmamo-2",
        "arcee-ai/Qwen2-Merger"
    ],
    "dataset_names": [
        "microsoft/orca-math-word-problems-200k",
        "MuskumPillerum/General-Knowledge"
    ]
}
45 changes: 45 additions & 0 deletions mergekit/_data/models_and_datasets.py
@@ -0,0 +1,45 @@
import json
from pathlib import Path

def save_model_and_dataset(model_name, dataset_name):
    models_and_datasets = presets().load()
    if model_name not in models_and_datasets['model_names']:
        models_and_datasets['model_names'].append(model_name)
    if dataset_name not in models_and_datasets['dataset_names']:
        models_and_datasets['dataset_names'].append(dataset_name)
    presets().save(models_and_datasets['model_names'], models_and_datasets['dataset_names'])

def model_and_dataset_to_index(model_name, dataset_name):
    models_and_datasets = presets().load()
    model_index = models_and_datasets['model_names'].index(model_name) if model_name in models_and_datasets['model_names'] else []
    dataset_index = models_and_datasets['dataset_names'].index(dataset_name) if dataset_name in models_and_datasets['dataset_names'] else []

    return model_index, dataset_index

def index_to_model_and_dataset(model_index, dataset_index):
    models_and_datasets = presets().load()
    model_name = models_and_datasets['model_names'][model_index] if len(models_and_datasets['model_names']) > model_index else []
    dataset_name = models_and_datasets['dataset_names'][dataset_index] if len(models_and_datasets['dataset_names']) > dataset_index else []
    return model_name, dataset_name

class presets():
    def __init__(self):
        self.FILE_PATH = Path(__file__).parent / 'models_and_datasets.json'

    def load(self):
        """Load the lists from the JSON file."""
        if self.FILE_PATH.exists():
            with open(self.FILE_PATH, 'r') as file:
                data = json.load(file)
            return data
        print(f"File {self.FILE_PATH} does not exist or is empty.")
        return {}

    def save(self, model_names, dataset_names):
        """Save the lists to the JSON file."""
        data = {
            'model_names': model_names,
            'dataset_names': dataset_names
        }
        with open(self.FILE_PATH, 'w') as file:
            json.dump(data, file, indent=4)
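
A minimal usage sketch of this presets helper (not part of the diff; the module path is assumed from the file location). Note that model_and_dataset_to_index falls back to an empty list rather than None when a name is unknown, so callers have to treat [] as the missing-value sentinel:

from mergekit._data.models_and_datasets import (
    save_model_and_dataset,
    model_and_dataset_to_index,
    index_to_model_and_dataset,
)

# Register a (model, dataset) pair in mergekit/_data/models_and_datasets.json.
save_model_and_dataset("Qwen/Qwen2-7B-Instruct", "MuskumPillerum/General-Knowledge")

# Round-trip between names and their indices in the stored lists.
m_idx, d_idx = model_and_dataset_to_index("Qwen/Qwen2-7B-Instruct", "MuskumPillerum/General-Knowledge")
assert index_to_model_and_dataset(m_idx, d_idx) == ("Qwen/Qwen2-7B-Instruct", "MuskumPillerum/General-Knowledge")
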
53 changes: 32 additions & 21 deletions mergekit/architecture.py
@@ -43,6 +43,10 @@ class WeightInfo(BaseModel, frozen=True):
            List of alternative names for the weight, if applicable.
        force_dtype (Optional[str]):
            Mandatory dtype for the weight, if applicable.
        num_key_value_heads (Optional[int]):
            Number of key-value heads in the weight, relevant for GQA, if applicable.
        num_attention_heads (Optional[int]):
            Number of attention heads in the weight, if applicable.
    """

    name: str
@@ -53,6 +57,9 @@ class WeightInfo(BaseModel, frozen=True):
    aliases: Optional[Tuple[str, ...]] = None
    force_dtype: Optional[str] = None

    num_key_value_heads: Union[int, str, None] = None
    num_attention_heads: Union[int, str, None] = None


class ProceduralSpaceInfo(BaseModel, frozen=True):
    """Defines a procedural space computed from one or more other spaces.
@@ -172,30 +179,29 @@ class JSONArchitectureDefinition(BaseModel, frozen=True):
class TemplateWithArithmetic(string.Template):
    idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)"


def _template_substitution(
    template: str, num_layers: int, layer_idx: Optional[int] = None
def _generic_template_substitution(
    template: str, match_str: str = "num_layers", replace_with: int = 0
) -> str:
    if "{" not in template:
    if f"{match_str}" not in template:
        return template

    substitutions = {
        "num_layers": num_layers,
        "num_layers+1": num_layers + 1,
        "num_layers-1": num_layers - 1,
        match_str: replace_with,
        f"{match_str}+1": replace_with + 1,
        f"{match_str}-1": replace_with - 1,
    }

    if layer_idx is not None:
        substitutions.update(
            {
                "layer_index": layer_idx,
                "layer_index+1": layer_idx + 1,
                "layer_index-1": layer_idx - 1,
            }
        )

    return TemplateWithArithmetic(template).substitute(substitutions)

def _template_substitution(
    template: str, substitute: Dict[str, Optional[int]]
) -> str:
    for key, value in substitute.items():
        if value is not None:
            template = _generic_template_substitution(template, key, value)

    return template


class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
    definition: JSONArchitectureDefinition
Expand All @@ -206,18 +212,23 @@ def _substitute(
        config: PretrainedConfig,
        layer_idx: Optional[int] = None,
    ) -> Union[WeightInfo, ProceduralSpaceInfo]:
        num_layers = self.num_layers(config)

        obj_dict = item.model_dump(mode="json", exclude_unset=True)
        substitute = {
            "num_layers": self.num_layers(config),
            "layer_index": layer_idx,
            "num_attention_heads": getattr(config, "num_attention_heads") if getattr(config, "num_attention_heads") is not None else None,
            "num_key_value_heads": getattr(config, "num_key_value_heads") if getattr(config, "num_key_value_heads") is not None else None,
        }

        obj_dict = item.model_dump(mode="json", exclude_unset=False)
        for key in obj_dict:
            if isinstance(obj_dict[key], str):
                obj_dict[key] = _template_substitution(
                    obj_dict[key], num_layers, layer_idx
                    obj_dict[key], substitute
                )
            elif isinstance(obj_dict[key], list):
                obj_dict[key] = [
                    (
                        _template_substitution(s, num_layers, layer_idx)
                        _template_substitution(s, substitute)
                        if isinstance(s, str)
                        else s
                    )
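
A sketch of what the reworked substitution does, using the helpers defined above (values are illustrative; in _substitute they come from the Hugging Face config). Each key in the substitute dict gets its own _generic_template_substitution pass, which also covers the ${key+1} and ${key-1} arithmetic forms:

from mergekit.architecture import _template_substitution

substitute = {
    "num_layers": 32,
    "layer_index": 5,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
}

_template_substitution("model.layers.${layer_index}.self_attn.k_proj.weight", substitute)
# -> 'model.layers.5.self_attn.k_proj.weight'
_template_substitution("h_${layer_index+1}", substitute)
# -> 'h_6'
_template_substitution("${num_key_value_heads}", substitute)
# -> '8'
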
11 changes: 10 additions & 1 deletion mergekit/config.py
@@ -83,7 +83,8 @@ class OutputSliceDefinition(BaseModel):


class MergeConfiguration(BaseModel):
    merge_method: str
    merge_method: Optional[str] = None
    metric_method: Optional[str] = None
    slices: Optional[List[OutputSliceDefinition]] = None
    models: Optional[List[InputModelDefinition]] = None
    parameters: Optional[Dict[str, ParameterSetting]] = None
@@ -114,6 +115,14 @@ def validate_inputs(self):
        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
            raise RuntimeError("Must specify either output slices or models to merge")
        return self

    @model_validator(mode="after")
    def validate_methods(self):
        if not self.merge_method and not self.metric_method:
            raise ValueError("Either 'merge_method' or 'metric_method' must be provided.")
        if self.merge_method and self.metric_method:
            raise ValueError("Only one of 'merge_method' or 'metric_method' can be provided, not both.")
        return self

    @model_validator(mode="after")
    def validate_tokenizer(self):
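
A quick sketch of the new validation behaviour (illustrative; pydantic wraps the ValueError in a ValidationError):

from mergekit.config import MergeConfiguration

# Exactly one of merge_method / metric_method must be set.
MergeConfiguration.model_validate(
    {"metric_method": "all", "models": [{"model": "huggyllama/llama-7b"}]}
)  # ok

MergeConfiguration.model_validate(
    {"merge_method": "linear", "metric_method": "all", "models": [{"model": "huggyllama/llama-7b"}]}
)  # raises: Only one of 'merge_method' or 'metric_method' can be provided, not both.
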
17 changes: 12 additions & 5 deletions mergekit/graph.py
@@ -37,6 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    Abstract base class representing a task in a computational graph.

    This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy.
    Pydantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after.

    Attributes:
        Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution.
@@ -106,7 +107,6 @@ def uses_accelerator(self) -> bool:
        """
        return False


class Executor:
    """
    Schedules and executes a set of tasks and their dependencies.
@@ -241,13 +241,20 @@ def _make_schedule(self, targets: List[Task]) -> List[Task]:
            # they will be included in the final schedule
            edge_tups.append((Executor.DUMMY_TASK_VALUE, task))

        def _pad_numbers(s):
            parts = s.split('.')
            for i, part in enumerate(parts):
                if part.isdigit():
                    parts[i] = part.zfill(3)
            return '.'.join(parts)

        def _compare_key(task: Union[Task, str]):
            if task == Executor.DUMMY_TASK_VALUE:
                return ("", 0)
            return (
                task.group_label() or "",
                -task.priority(),
            )
            group_label = task.group_label() or ""
            padded_label = _pad_numbers(group_label)
            priority = -task.priority()
            return (padded_label, priority)

        graph = networkx.DiGraph(edge_tups)
        res = [
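
The _pad_numbers helper addresses the lexicographic-sort issue mentioned in commit f89df57: group labels embed layer numbers, and plain string comparison orders layer 10 before layer 2. A small self-contained illustration with assumed labels:

labels = ["model.layers.2.mlp", "model.layers.10.mlp", "model.layers.1.mlp"]

sorted(labels)
# ['model.layers.1.mlp', 'model.layers.10.mlp', 'model.layers.2.mlp']

def _pad_numbers(s):
    # zero-pad purely numeric dotted components so they compare numerically
    parts = s.split('.')
    for i, part in enumerate(parts):
        if part.isdigit():
            parts[i] = part.zfill(3)
    return '.'.join(parts)

sorted(labels, key=_pad_numbers)
# ['model.layers.1.mlp', 'model.layers.2.mlp', 'model.layers.10.mlp']
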
6 changes: 6 additions & 0 deletions mergekit/merge.py
@@ -92,11 +92,17 @@ def run_merge(
        storage_device="cuda" if options.low_cpu_memory else "cpu",
    )

    metrics_out = []
    tokenizer = None
    for _task, value in exec.run(quiet=options.quiet):
        if merge_config.metric_method is not None:
            metrics_out.append((_task, value))
        if isinstance(value, TokenizerInfo):
            tokenizer = value.tokenizer

    if metrics_out:
        return metrics_out

    if tokenizer:
        _update_config_vocab(cfg_out, tokenizer)

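
With this change run_merge has two result modes: when metric_method is set it returns the collected (task, value) pairs, and otherwise it falls through to the normal merge path, which (as far as this excerpt shows) returns nothing. A caller distinguishing the two, assuming the API sketched after the example configs above:

result = run_merge(config, out_path, options=options)
if result is not None:
    # metric run: (task, value) pairs produced by the executor
    for task, value in result:
        ...
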
29 changes: 29 additions & 0 deletions mergekit/metric_methods/__init__.py
@@ -0,0 +1,29 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from mergekit.metric_methods.base import MetricMethod
from mergekit.metric_methods.all_metrics import AllMetric


def get(method: str) -> MetricMethod:
    if method == "all":
        return AllMetric()
    raise RuntimeError(f"Unimplemented metric method {method}")


__all__ = [
    "MetricMethod",
    "get",
]
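
Usage mirrors mergekit's existing merge_methods registry (sketch, not part of the diff):

from mergekit import metric_methods

method = metric_methods.get("all")    # returns an AllMetric instance
metric_methods.get("does-not-exist")  # raises RuntimeError: Unimplemented metric method ...
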