
Weights Metrics #340

Open · wants to merge 67 commits into base: main
Conversation

ElliotStein (Author):

Implemented:

  • Framework to compute metrics based on layer weights, using existing mergekit infrastructure (run_measure is based on run_merge, metric_methods on merge_methods, etc.).
  • plot_tools.MetricsHandler to load the metrics output and to process and interact with the statistics.
  • plot_tools.ModelGraph to generate a graph representing the model structure, with node-level statistics visible by hovering over a node and more detailed stats (histograms rather than means) available by clicking on a node.
  • run_metrics.py ties everything together and generates an interactive dashboard displaying the ModelGraph graph.

Not Implemented:

  • Split layers into individual heads
  • Activation based metrics
  • Unit tests
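The MetricsHandler side of the implemented list can be sketched roughly as below. All names and the shape of the metrics output are illustrative assumptions, not the PR's actual API:

```python
from statistics import mean

# Hypothetical shape of the metrics output: per-layer dicts of metric name -> raw values.
raw_metrics = {
    "model.layers.0.self_attn": {"cosine_similarity": [0.875, 0.75, 1.0]},
    "model.layers.0.mlp": {"cosine_similarity": [0.75, 0.5]},
}


class MetricsHandler:
    """Illustrative sketch: load per-layer metric values and expose summary stats."""

    def __init__(self, metrics):
        self.metrics = metrics

    def layer_names(self):
        return list(self.metrics)

    def layer_mean(self, layer, metric):
        # Node-level statistic shown on hover; the click-through histograms
        # would be built from the raw value lists instead.
        return mean(self.metrics[layer][metric])


handler = MetricsHandler(raw_metrics)
print(handler.layer_mean("model.layers.0.mlp", "cosine_similarity"))  # 0.625
```

The real ModelGraph then attaches these summaries to networkx nodes and renders them with plotly hover text.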

run_metrics.py (Outdated)
Collaborator:

This file should probably be in mergekit/scripts.

Collaborator:

Also would be good to use click to turn the hardcoded values into arguments.

from typing import List, Dict, Optional, Any, Tuple
from mergekit.graph import Task
import networkx as nx
import plotly.graph_objects as go
Collaborator:

We should capture these new dependencies in pyproject.toml. Probably under a feature, so headless installs don't need to bring them in.
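For example, something along these lines in pyproject.toml (the group name `interactive` is illustrative):

```toml
[project.optional-dependencies]
interactive = [
    "networkx",
    "plotly",
]
```

Headless installs then stay lean, and dashboard users opt in with `pip install mergekit[interactive]`.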

mergekit/plan.py (Outdated)
)
finalize = FinalizeModel(
Collaborator:

Totally fine to skip the finalize task when we're doing metrics, but it's needed for merges - I think as-is this makes merges not write out correctly.

**_kwargs,
) -> Task:

if 'self_attn' in output_weight.name:
Collaborator:

Down the line we probably want this split to be done based on new fields in ArchitectureInfo but this is good for now!


res = {}

scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2)
Collaborator:

Should we be doing something here to guard against dividing by zero?

Author:

yep - norms are non-negative, so adding a small epsilon will be fine
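The guarded version might look like this (the eps value is illustrative; plain `abs()` is used so the sketch runs on Python floats, though the real code operates on torch tensors):

```python
def scale_diff(norm_0, norm_1, eps=1e-8):
    # Norms are non-negative, so the denominator can only vanish when both
    # norms are zero; eps keeps that case finite (and the result is 0).
    return abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2 + eps)


print(scale_diff(0.0, 0.0))  # 0.0 instead of a ZeroDivisionError / NaN
```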

@@ -53,6 +57,9 @@ class WeightInfo(BaseModel, frozen=True):
aliases: Optional[Tuple[str, ...]] = None
force_dtype: Optional[str] = None

GQA_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA
Collaborator:

Should be gqa_groups

num_heads=32 # hard-coded for now
)
self.block_count += 1
return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info)
Collaborator:

Does this end up creating N AttnTasks for each block? I don't think it's actually a problem as the tasks will be deduplicated downstream - should be fine

Author:

Should only be one AttnTask for each block - the if statement on line 351 is only satisfied once all the tensors (K,Q,V,O) have been collected. Then self.attn_weight_dict is reset to {} and the (one) AttnTask is created. I might also add individual tensor metrics for comparing just the Qs, Vs etc, which would be simpler.
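The collect-then-emit pattern described here can be sketched like so — class and attribute names mirror the discussion, but the code is an illustration, not the PR's planner:

```python
# Tensors for a block accumulate in a dict; the single AttnTask is only
# created once all four projections (q, k, v, o) have been collected.
EXPECTED = {"q_proj", "k_proj", "v_proj", "o_proj"}


class Planner:
    def __init__(self):
        self.attn_weight_dict = {}
        self.tasks = []

    def add_weight(self, name, tensor):
        # e.g. "q_proj" from "model.layers.0.self_attn.q_proj.weight"
        kind = name.rsplit(".", 2)[-2]
        self.attn_weight_dict[kind] = tensor
        if EXPECTED <= self.attn_weight_dict.keys():
            self.tasks.append(dict(self.attn_weight_dict))  # stand-in for AttnTask
            self.attn_weight_dict = {}  # reset for the next block


p = Planner()
for i, proj in enumerate(["q_proj", "k_proj", "v_proj", "o_proj"]):
    p.add_weight(f"model.layers.0.self_attn.{proj}.weight", i)
print(len(p.tasks))  # 1
```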

self._method = merge_methods.get(config.merge_method)
if getattr(config, "merge_method", None):
self._method = merge_methods.get(config.merge_method)
elif getattr(config, "metric_method", None):
Collaborator:

Would be good to add a validator to MergeConfig that checks that exactly one of these fields is set.
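A possible shape for that validator, assuming pydantic v2 syntax (the class name and field defaults here are illustrative):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class MergeConfiguration(BaseModel):
    merge_method: Optional[str] = None
    metric_method: Optional[str] = None

    @model_validator(mode="after")
    def exactly_one_method(self):
        # Reject both-set and neither-set configurations.
        if bool(self.merge_method) == bool(self.metric_method):
            raise ValueError("Exactly one of merge_method or metric_method must be set")
        return self


cfg = MergeConfiguration(metric_method="cosine_similarity")
print(cfg.metric_method)  # cosine_similarity
```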

)

res = []
for _task, value in exec.run(quiet=options.quiet):
Collaborator:

Looking this over, I kinda think we might not need a separate file here - maybe it should just early out in merge.py if there's a metric_method set instead of merge_method?

@@ -37,6 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
Abstract base class representing a task in a computational graph.

This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy.
Note that PyDantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after.
Collaborator:

Super nitpick here: I think the official capitalization is Pydantic, not PyDantic.
