Add merge tests and fix various problems
cg123 committed Dec 31, 2023
1 parent 5e6e453 commit b184ae2
Showing 10 changed files with 228 additions and 53 deletions.
40 changes: 18 additions & 22 deletions mergekit/config.py
@@ -116,7 +116,6 @@ class ConfigReader(BaseModel):
     t: float
     tensor_name: Optional[str] = None
     slice_out: Optional[OutputSliceDefinition] = None
-    slices_in: Optional[List[InputSliceDefinition]] = None
 
     @property
     def base_model(self) -> Optional[ModelReference]:
@@ -135,16 +134,6 @@ def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
             t=self.t,
             tensor_name=self.tensor_name,
             slice_out=slice,
-            slices_in=self.slices_in,
         )
 
-    def for_in_slices(self, slices: List[InputSliceDefinition]) -> "ConfigReader":
-        return ConfigReader(
-            config=self.config,
-            t=self.t,
-            tensor_name=self.tensor_name,
-            slice_out=self.slice_out,
-            slices_in=slices,
-        )
-
     def for_tensor(self, tensor_name: str) -> "ConfigReader":
@@ -153,7 +142,6 @@ def for_tensor(self, tensor_name: str) -> "ConfigReader":
             t=self.t,
             tensor_name=tensor_name,
             slice_out=self.slice_out,
-            slices_in=self.slices_in,
         )
 
     def with_t(self, t: float) -> "ConfigReader":
@@ -162,7 +150,6 @@ def with_t(self, t: float) -> "ConfigReader":
             t=t,
             tensor_name=self.tensor_name,
             slice_out=self.slice_out,
-            slices_in=self.slices_in,
         )
 
     def parameter(
@@ -172,16 +159,16 @@ def parameter(
         default: Any = None,
         required: bool = False,
     ) -> Any:
-        if model and self.slices_in:
-            for s in self.slices_in:
-                if s.model == str(model) and s.parameters and name in s.parameters:
-                    value = evaluate_setting(
-                        self.tensor_name, s.parameters[name], self.t
-                    )
-                    if value is not None:
-                        return value
-
         if self.slice_out:
+            if model:
+                for s in self.slice_out.sources:
+                    if s.model == str(model) and s.parameters and name in s.parameters:
+                        value = evaluate_setting(
+                            self.tensor_name, s.parameters[name], self.t
+                        )
+                        if value is not None:
+                            return value
+
             if self.slice_out.parameters and name in self.slice_out.parameters:
                 value = evaluate_setting(
                     self.tensor_name, self.slice_out.parameters[name], self.t
@@ -216,5 +203,14 @@ def parameter(
             path_paths = [str(s) for s in [model, self.tensor_name] if s]
             p = ".".join(path_paths)
             suffix = f" for {p}" if p else ""
+            print(f"name: {name}")
+            print(f"model: {model}")
+            print(f"slice_out: {self.slice_out}")
+            for s in self.slice_out.sources:
+                print(repr(s))
+                print(s.model == str(model))
+                print(bool(s.parameters))
+                print(name in s.parameters)
+            print(repr(self))
             raise RuntimeError(f"Missing required parameter {name}{suffix}")
         return default
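
With slices_in gone, ConfigReader.parameter resolves per-model values against the output slice alone. A standalone sketch of the lookup order this file now implements — a hypothetical reduction, not the mergekit API itself; the real method also routes every hit through evaluate_setting with the tensor name and interpolation parameter t, and consults config-level parameters before giving up:

    # Hypothetical reduction of the new lookup order.
    from typing import Any, Optional

    def resolve_parameter(
        name: str,
        model: Optional[str],
        slice_out,            # an OutputSliceDefinition, or None
        global_params: dict,  # config-level parameters
        default: Any = None,
        required: bool = False,
    ) -> Any:
        if slice_out:
            # 1. per-source parameters on the output slice, matched by model
            if model:
                for s in slice_out.sources:
                    if s.model == str(model) and s.parameters and name in s.parameters:
                        return s.parameters[name]
            # 2. slice-wide parameters
            if slice_out.parameters and name in slice_out.parameters:
                return slice_out.parameters[name]
        # 3. config-level parameters, then the default (or an error if required)
        if global_params and name in global_params:
            return global_params[name]
        if required:
            raise RuntimeError(f"Missing required parameter {name}")
        return default

Keeping per-model parameters on the slice's own sources removes the redundant slices_in channel that the deleted for_in_slices plumbing existed to carry.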
8 changes: 4 additions & 4 deletions mergekit/merge_methods/base.py
@@ -14,11 +14,11 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 
 from pydantic import BaseModel
 
-from mergekit.common import ModelReference
+from mergekit.common import ImmutableMap, ModelReference
 from mergekit.graph import Task
 from mergekit.tasks import GatherTensors
@@ -42,8 +42,8 @@ def make_task(
         *,
         output_tensor_name: str,
         tensors: GatherTensors,
-        parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        parameters: ImmutableMap[str, Any],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         base_model: Optional[ModelReference],
     ) -> Task:
         ...
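
The switch from Dict to ImmutableMap in this signature is what lets make_task produce hashable Task objects: a pydantic model holding a plain dict cannot be hashed, while one holding an ImmutableMap can. A rough sketch of the contract such a type satisfies (illustrative only; the actual class lives in mergekit.common):

    from typing import Dict, Generic, Iterator, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")

    class FrozenMapSketch(Generic[K, V]):
        """Illustrative stand-in for mergekit.common.ImmutableMap."""

        def __init__(self, data: Dict[K, V]):
            self._data = dict(data)  # private copy; never mutated afterwards

        def __getitem__(self, key: K) -> V:
            return self._data[key]

        def __iter__(self) -> Iterator[K]:
            return iter(self._data)

        def __len__(self) -> int:
            return len(self._data)

        def keys(self):
            return self._data.keys()

        def items(self):
            return self._data.items()

        def __hash__(self) -> int:
            # The point of the exercise: a stable hash, so tasks carrying
            # these maps can be deduplicated and used as graph keys.
            return hash(tuple(sorted((repr(k), repr(v)) for k, v in self._data.items())))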
12 changes: 6 additions & 6 deletions mergekit/merge_methods/generalized_task_arithmetic.py
@@ -33,7 +33,7 @@ class ConsensusMethod(str, Enum):
     sum = "sum"
 
 
-class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel):
+class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
     consensus_method: Optional[ConsensusMethod]
     sparsification_method: Optional[SparsificationMethod]
     default_normalize: bool
@@ -57,14 +57,14 @@ def make_task(
         output_tensor_name: str,
         tensors: GatherTensors,
         base_model: Optional[ModelReference],
-        parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        parameters: ImmutableMap[str, Any],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
     ) -> Task:
         return GTATask(
             method=self,
             tensors=tensors,
             base_model=base_model,
-            tensor_parameters=ImmutableMap(data=tensor_parameters),
+            tensor_parameters=tensor_parameters,
             int8_mask=parameters["int8_mask"],
             normalize=parameters["normalize"],
             out_tensor_name=output_tensor_name,
@@ -141,8 +141,8 @@ def execute(
 def get_task_vectors(
     parameter_name: str,
     base_model: ModelReference,
-    tensors: Dict[ModelReference, torch.Tensor],
-    tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+    tensors: ImmutableMap[ModelReference, torch.Tensor],
+    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
 ) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
     keys = list(tensors.keys())
     base = tensors[base_model]
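
Adding frozen=True here matters for the same hashability reason: GTATask stores method=self as a field, and every task in the graph must be hashable. A frozen pydantic model rejects mutation and gains a field-based __hash__. A minimal demonstration, assuming pydantic v2 (matching the class-keyword form used above):

    from pydantic import BaseModel

    class Mutable(BaseModel):
        x: int

    class Frozen(BaseModel, frozen=True):
        x: int

    # hash(Mutable(x=1))  # would raise TypeError: unhashable type
    assert hash(Frozen(x=1)) == hash(Frozen(x=1))  # usable as a dict/graph key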
6 changes: 3 additions & 3 deletions mergekit/merge_methods/linear.py
@@ -26,7 +26,7 @@
 
 class LinearMergeTask(Task[torch.Tensor]):
     gather_tensors: GatherTensors
-    tensor_parameters: ImmutableMap[ModelReference, Dict[str, Any]]
+    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
     normalize: bool
     parameter_name: str
 
@@ -74,12 +74,12 @@ def make_task(
         output_tensor_name: str,
         tensors: GatherTensors,
         parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         **_kwargs,
     ) -> Task:
         return LinearMergeTask(
             gather_tensors=tensors,
-            tensor_parameters=ImmutableMap(data=tensor_parameters),
+            tensor_parameters=tensor_parameters,
             normalize=parameters["normalize"],
             parameter_name=output_tensor_name,
         )
4 changes: 2 additions & 2 deletions mergekit/merge_methods/passthrough.py
@@ -18,7 +18,7 @@
 import torch
 from torch._tensor import Tensor
 
-from mergekit.common import ModelReference
+from mergekit.common import ImmutableMap, ModelReference
 from mergekit.graph import Task
 from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
 from mergekit.tasks import GatherTensors
@@ -50,7 +50,7 @@ def make_task(
         self,
         *,
         tensors: GatherTensors,
-        parameters: Dict[str, Any],
+        parameters: ImmutableMap[str, Any],
         **kwargs,
     ) -> Task:
         return PassthroughMergeTask(gather_tensors=tensors, scale=parameters["scale"])
16 changes: 9 additions & 7 deletions mergekit/merge_methods/slerp.py
@@ -19,7 +19,7 @@
 import torch
 from torch._tensor import Tensor
 
-from mergekit.common import ModelReference, rectify_embed_sizes
+from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
 from mergekit.graph import Task
 from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
 from mergekit.tasks import GatherTensors
@@ -34,13 +34,15 @@ class SlerpTask(Task[torch.Tensor]):
     def arguments(self) -> Dict[str, Task]:
         return {"tensors": self.gather_tensors}
 
-    def execute(self, input_tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
-        if len(input_tensors) == 1:
-            return list(input_tensors.values())[0]
-        elif len(input_tensors) != 2:
+    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
+        if len(tensors) == 1:
+            return list(tensors.values())[0]
+        elif len(tensors) != 2:
             raise RuntimeError("Slerp merge expects exactly two models")
+        elif self.base_model not in tensors:
+            raise RuntimeError("Base model not in input tensors")
 
-        [a, b] = list(input_tensors.items())
+        [a, b] = list(tensors.items())
         if a[0] != self.base_model:
             [a, b] = [b, a]
         prepped_tensors = [a[1], b[1]]
@@ -67,7 +69,7 @@ def make_task(
         *,
         output_tensor_name: str,
         tensors: GatherTensors,
-        parameters: Dict[str, Any],
+        parameters: ImmutableMap[str, Any],
         base_model: ModelReference | None,
         **_kwargs,
     ) -> Task:
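
The new guard guarantees the base model is one of the two inputs, so the reordering below it can always put the base tensor first; t then interpolates from the base model (t=0) to the secondary model (t=1). For reference, the spherical interpolation performed on the two prepared tensors looks roughly like this — a sketch; the real implementation also falls back to linear interpolation for near-colinear vectors and rectifies embedding sizes first:

    import torch

    def slerp_sketch(t: float, a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # angle between the two tensors, treated as flat vectors
        a_n = a.flatten() / (a.norm() + eps)
        b_n = b.flatten() / (b.norm() + eps)
        omega = torch.acos(torch.clamp(torch.dot(a_n, b_n), -1 + eps, 1 - eps))
        so = torch.sin(omega)
        # weights follow sin((1-t)*omega) and sin(t*omega): t=0 yields a, t=1 yields b
        return (torch.sin((1 - t) * omega) / so) * a + (torch.sin(t * omega) / so) * b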
4 changes: 2 additions & 2 deletions mergekit/merge_methods/tokenizer_permute.py
@@ -126,7 +126,7 @@ def make_task(
         *,
         tensors: GatherTensors,
         parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         base_model: Optional[ModelReference],
         **_kwargs,
     ) -> Task:
@@ -136,5 +136,5 @@ def make_task(
             gather_tensors=tensors,
             use_slerp=parameters["embed_slerp"],
             slerp_t=parameters["t"],
-            tensor_parameters=ImmutableMap(data=tensor_parameters),
+            tensor_parameters=tensor_parameters,
         )
81 changes: 74 additions & 7 deletions mergekit/plan.py
@@ -13,6 +13,7 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
+import logging
 from typing import List, Optional
 
 from mergekit import merge_methods
@@ -35,6 +36,7 @@ class MergePlanner:
     config: MergeConfiguration
     arch_info: ArchitectureInfo
     clone_tensors: bool
+    trust_remote_code: bool
     _writer_task: TensorWriterTask
     _method: MergeMethod
     _tasks: List[Task] = []
@@ -51,6 +53,7 @@ def __init__(
         self.config = config
         self.arch_info = arch_info
         self.clone_tensors = options.clone_tensors
+        self.trust_remote_code = options.trust_remote_code
         self._method = merge_methods.get(config.merge_method)
         self._writer_task = TensorWriterTask(
             out_path=out_path, max_shard_size=options.out_shard_size
@@ -68,6 +71,53 @@ def __init__(
             trust_remote_code=options.trust_remote_code,
         )
 
+    def normalize_config(self):
+        base_model = (
+            ModelReference.parse(self.config.base_model)
+            if self.config.base_model
+            else None
+        )
+
+        # if models to merge are specified instead of output slices, compute them
+        if self.config.models:
+            if self.config.slices:
+                raise RuntimeError(
+                    "Must specify either models to merge or output slices"
+                )
+
+            slices_in = []
+            base_included = False
+
+            for model_in in self.config.models:
+                mref = ModelReference.parse(model_in.model)
+
+                if base_model and mref == base_model:
+                    base_included = True
+
+                model_cfg = mref.config(trust_remote_code=self.trust_remote_code)
+                num_layers = self.arch_info.num_layers(model_cfg)
+                slices_in.append(
+                    InputSliceDefinition(
+                        layer_range=[0, num_layers],
+                        model=model_in.model,
+                        parameters=model_in.parameters,
+                    )
+                )
+
+            if base_model and not base_included:
+                logging.info("Base model specified but not in input models - adding")
+                base_cfg = base_model.config(trust_remote_code=self.trust_remote_code)
+                num_layers = self.arch_info.num_layers(base_cfg)
+                slices_in.append(
+                    InputSliceDefinition(
+                        layer_range=[0, num_layers],
+                        model=str(base_model),
+                    )
+                )
+
+            self.config.slices = [OutputSliceDefinition(sources=slices_in)]
+            self.config.models = None
+
     def plan_tensor(
         self,
         name: str,
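
In effect, normalize_config lets a config list models instead of hand-written slices and rewrites it into the slice form the rest of the planner expects, appending the base model as an extra all-layers source if it was specified but absent from the list. Roughly, for a two-model merge (class names as in mergekit/config.py; the layer count 32 is illustrative, taken from each model's architecture info at runtime):

    # A config that arrives as:
    #   models:
    #     - model: model_a
    #       parameters: {weight: 0.5}
    #     - model: model_b
    #       parameters: {weight: 0.5}
    # leaves normalize_config as if it had been written:
    config.slices = [
        OutputSliceDefinition(
            sources=[
                InputSliceDefinition(model="model_a", layer_range=[0, 32],
                                     parameters={"weight": 0.5}),
                InputSliceDefinition(model="model_b", layer_range=[0, 32],
                                     parameters={"weight": 0.5}),
            ]
        )
    ]
    config.models = None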
@@ -82,7 +132,7 @@ def plan_tensor(
             tokenizer_task=self._tokenizer_task
         )
 
-        cfg_g = cfg_reader.for_in_slices(None).for_tensor(name)
+        cfg_g = cfg_reader.for_tensor(name)
         global_params = {}
         for p in tensor_merge_method.parameters():
             global_params[p.name] = cfg_g.parameter(
@@ -91,11 +141,15 @@ def plan_tensor(
 
         tensor_params = {}
         for model, name_in in zip(models, names_in):
+            is_base = str(model) == cfg_reader.config.base_model
             tensor_params[model] = {}
             cfg_m = cfg_reader.for_tensor(name_in)
             for p in tensor_merge_method.tensor_parameters():
                 tensor_params[model][p.name] = cfg_m.parameter(
-                    p.name, model=model, required=p.required, default=p.default_value
+                    p.name,
+                    model=model,
+                    required=p.required and not is_base,
+                    default=p.default_value,
                 )
 
         gather_tensors = GatherTensors(
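
The is_base flag relaxes required tensor parameters for the base model: with a task-arithmetic-style method where "weight" is required, every non-base model must supply one, but the base model may omit it. In miniature (hypothetical values):

    # required=p.required and not is_base, illustrated:
    p_required = True
    base_model = "org/base"
    for model in ["org/base", "org/finetune"]:
        is_base = model == base_model
        print(model, "requires weight:", p_required and not is_base)
    # org/base requires weight: False
    # org/finetune requires weight: True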
@@ -111,8 +165,12 @@ def plan_tensor(
         tensor_task = tensor_merge_method.make_task(
             output_tensor_name=name,
             tensors=gather_tensors,
-            parameters=global_params,
-            tensor_parameters=tensor_params,
+            parameters=ImmutableMap(data=global_params),
+            tensor_parameters=ImmutableMap(
+                data={
+                    key: ImmutableMap(data=tensor_params[key]) for key in tensor_params
+                }
+            ),
             base_model=base_model,
         )
         save_task = SaveTensor(
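
This nested wrapping is now the single point where the planner's plain dicts are frozen; merge methods receive ready-made immutable maps and no longer wrap them themselves (hence the tensor_parameters=tensor_parameters simplifications in the method files above). In isolation, with string keys standing in for ModelReference (a sketch):

    from mergekit.common import ImmutableMap

    global_params = {"normalize": True}
    tensor_params = {"model_a": {"weight": 0.5}, "model_b": {"weight": 0.5}}

    parameters = ImmutableMap(data=global_params)
    tensor_parameters = ImmutableMap(
        data={key: ImmutableMap(data=tensor_params[key]) for key in tensor_params}
    )
    assert tensor_parameters["model_a"]["weight"] == 0.5  # read access unchanged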
@@ -165,18 +223,23 @@ def plan_slice(self, definition: OutputSliceDefinition):
             definition.sources,
             layer_offset=idx,
             t=t,
-            cfg_reader=cfg_reader.for_in_slices(definition.sources),
+            cfg_reader=cfg_reader,
         )
 
     def plan(self):
+        self.normalize_config()
         self._tasks = []
 
         for weight_name in self.arch_info.pre_weights():
             self.plan_tensor(
                 weight_name,
                 [weight_name] * len(self.config.slices[0].sources),
                 [ModelReference.parse(s.model) for s in self.config.slices[0].sources],
-                ConfigReader(config=self.config, t=0, tensor_name=weight_name),
+                ConfigReader(
+                    config=self.config,
+                    t=0,
+                    tensor_name=weight_name,
+                ).for_out_slice(self.config.slices[0]),
             )
 
         for out_slice in self.config.slices:
@@ -187,7 +250,11 @@ def plan(self):
                 weight_name,
                 [weight_name] * len(self.config.slices[-1].sources),
                 [ModelReference.parse(s.model) for s in self.config.slices[-1].sources],
-                ConfigReader(config=self.config, t=1, tensor_name=weight_name),
+                ConfigReader(
+                    config=self.config,
+                    t=1,
+                    tensor_name=weight_name,
+                ).for_out_slice(self.config.slices[-1]),
             )
 
         self._tasks.append(