Skip to content

Commit

Permalink
Merge branch 'main' into abm
Browse files — browse the repository at this point in the history
  • Loading branch information
metric-space authored Jul 15, 2024
2 parents fab765b + 4c3532c commit 2fd3e95
Show file tree
Hide file tree
Showing 15 changed files with 498 additions and 79 deletions.
14 changes: 12 additions & 2 deletions mergekit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import yaml
from pydantic import BaseModel, model_validator
from typing_extensions import TypeAlias
from typing_extensions import Literal, TypeAlias

from mergekit.common import ModelReference
from mergekit.tokenizer.config import TokenizerConfig

ScalarOrGradient: TypeAlias = Union[float, List[float]]

Expand Down Expand Up @@ -88,7 +89,10 @@ class MergeConfiguration(BaseModel):
parameters: Optional[Dict[str, ParameterSetting]] = None
base_model: Optional[ModelReference] = None
dtype: Optional[str] = None
tokenizer_source: Optional[str] = None
tokenizer_source: Union[
Literal["union"], Literal["base"], ModelReference, None
] = None
tokenizer: Optional[TokenizerConfig] = None
out_dtype: Optional[str] = None

def referenced_models(self) -> List[ModelReference]:
Expand All @@ -110,6 +114,12 @@ def validate_inputs(self):
raise RuntimeError("Must specify either output slices or models to merge")
return self

@model_validator(mode="after")
def validate_tokenizer(self):
    """Ensure the tokenizer options are not over-specified.

    The legacy ``tokenizer_source`` field and the newer ``tokenizer``
    config section are mutually exclusive ways of configuring tokenizer
    handling; a config may set at most one of them.
    """
    both_set = bool(self.tokenizer_source) and bool(self.tokenizer)
    if both_set:
        raise RuntimeError("Cannot specify both tokenizer_source and tokenizer")
    return self

def to_yaml(self) -> str:
return yaml.dump(
self.model_dump(exclude_defaults=True, mode="json"),
Expand Down
8 changes: 6 additions & 2 deletions mergekit/merge_methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from pydantic import BaseModel
from typing_extensions import TypeAlias

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.tokenizer import PermutedEmbeddings

MergeTensorInput: TypeAlias = Union[GatherTensors, PermutedEmbeddings]


class ConfigParameterDef(BaseModel):
Expand All @@ -42,7 +46,7 @@ def make_task(
self,
*,
output_weight: WeightInfo,
tensors: GatherTensors,
tensors: MergeTensorInput,
parameters: ImmutableMap[str, Any],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
base_model: Optional[ModelReference],
Expand Down
11 changes: 7 additions & 4 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.sparsify import SparsificationMethod, sparsify


Expand Down Expand Up @@ -68,7 +71,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
def make_task(
self,
output_weight: WeightInfo,
tensors: GatherTensors,
tensors: MergeTensorInput,
base_model: Optional[ModelReference],
parameters: ImmutableMap[str, Any],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
Expand All @@ -87,7 +90,7 @@ def make_task(

class GTATask(Task[torch.Tensor]):
method: GeneralizedTaskArithmeticMerge
tensors: GatherTensors
tensors: MergeTensorInput
base_model: ModelReference
weight_info: WeightInfo
tensor_parameters: ImmutableMap[ModelReference, Any]
Expand Down
11 changes: 7 additions & 4 deletions mergekit/merge_methods/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes


class LinearMergeTask(Task[torch.Tensor]):
gather_tensors: GatherTensors
gather_tensors: MergeTensorInput
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
normalize: bool
weight_info: WeightInfo
Expand Down Expand Up @@ -81,7 +84,7 @@ def make_task(
self,
*,
output_weight: WeightInfo,
tensors: GatherTensors,
tensors: MergeTensorInput,
parameters: Dict[str, Any],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
**_kwargs,
Expand Down
11 changes: 7 additions & 4 deletions mergekit/merge_methods/model_stock.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes


class ModelStockMergeTask(Task[torch.Tensor]):
gather_tensors: GatherTensors
gather_tensors: MergeTensorInput
base_model: ModelReference
weight_info: WeightInfo
filter_wise: bool = False
Expand Down Expand Up @@ -120,7 +123,7 @@ def make_task(
self,
*,
output_weight: WeightInfo,
tensors: GatherTensors,
tensors: MergeTensorInput,
base_model: Optional[ModelReference],
parameters: ImmutableMap[str, Any],
**_kwargs,
Expand Down
11 changes: 7 additions & 4 deletions mergekit/merge_methods/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)


class PassthroughMergeTask(Task[torch.Tensor]):
gather_tensors: GatherTensors
gather_tensors: MergeTensorInput
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]

def arguments(self) -> Dict[str, Task]:
Expand Down Expand Up @@ -52,7 +55,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
def make_task(
self,
*,
tensors: GatherTensors,
tensors: MergeTensorInput,
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
**kwargs,
) -> Task:
Expand Down
11 changes: 7 additions & 4 deletions mergekit/merge_methods/slerp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes


class SlerpTask(Task[torch.Tensor]):
gather_tensors: GatherTensors
gather_tensors: MergeTensorInput
base_model: ModelReference
t: float
weight_info: WeightInfo
Expand Down Expand Up @@ -75,7 +78,7 @@ def make_task(
self,
*,
output_weight: WeightInfo,
tensors: GatherTensors,
tensors: MergeTensorInput,
parameters: ImmutableMap[str, Any],
base_model: Optional[ModelReference],
**_kwargs,
Expand Down
11 changes: 7 additions & 4 deletions mergekit/merge_methods/tokenizer_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.merge_methods.slerp import slerp
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo


class TokenizerPermutationMergeTask(Task[torch.Tensor]):
tokenizer_task: BuildTokenizer
gather_tensors: GatherTensors
gather_tensors: MergeTensorInput
base_model: Optional[ModelReference]
use_slerp: bool
slerp_t: Optional[float]
Expand Down Expand Up @@ -134,7 +137,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
def make_task(
self,
*,
tensors: GatherTensors,
tensors: MergeTensorInput,
parameters: Dict[str, Any],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
base_model: Optional[ModelReference],
Expand Down
32 changes: 22 additions & 10 deletions mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@
TensorWriterTask,
)
from mergekit.merge_methods import MergeMethod
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
from mergekit.options import MergeOptions
from mergekit.tokenizer import BuildTokenizer
from mergekit.tokenizer import BuildTokenizer, PermutedEmbeddings


class MergePlanner:
Expand All @@ -68,12 +67,18 @@ def __init__(
self.out_model_config = out_model_config
self._method = merge_methods.get(config.merge_method)

if config.tokenizer_source:
token_cfg = {}
tokenizer_source = config.tokenizer_source
if config.tokenizer is not None:
token_cfg = config.tokenizer.tokens or {}
tokenizer_source = config.tokenizer.source
if tokenizer_source is not None:
self._tokenizer_task = BuildTokenizer(
base_model=config.base_model,
referenced_models=tuple(config.referenced_models()),
tokenizer_source=config.tokenizer_source,
tokenizer_source=tokenizer_source,
trust_remote_code=options.trust_remote_code,
add_tokens=tuple(token_cfg.keys()),
)

@lru_cache
Expand Down Expand Up @@ -143,11 +148,6 @@ def plan_tensor(
return

tensor_merge_method = self._method
if self._tokenizer_task and weight.is_embed:
tensor_merge_method = TokenizerPermutationMerge(
tokenizer_task=self._tokenizer_task
)

cfg_g = cfg_reader.for_tensor(weight.name)
global_params = {}
for p in tensor_merge_method.parameters():
Expand Down Expand Up @@ -176,9 +176,21 @@ def plan_tensor(
device="cuda" if self.options.read_to_gpu else None,
)

tensor_input_task = gather_tensors
if self._tokenizer_task and weight.is_embed:
token_cfg = {}
if cfg_reader.config.tokenizer:
token_cfg = cfg_reader.config.tokenizer.tokens
tensor_input_task = PermutedEmbeddings(
gather_tensors=gather_tensors,
tokenizer_task=self._tokenizer_task,
tokens=token_cfg,
base_model=base_model,
)

tensor_task = tensor_merge_method.make_task(
output_weight=weight,
tensors=gather_tensors,
tensors=tensor_input_task,
parameters=ImmutableMap(data=global_params),
tensor_parameters=ImmutableMap(
data={
Expand Down
20 changes: 20 additions & 0 deletions mergekit/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
from mergekit.tokenizer.config import TokenizerConfig
from mergekit.tokenizer.embed import PermutedEmbeddings

__all__ = ["BuildTokenizer", "TokenizerInfo", "TokenizerConfig", "PermutedEmbeddings"]
Loading

0 comments on commit 2fd3e95

Please sign in to comment.