From 4c3532cd1f7a21bfefe032212c8cd50e5e685ac2 Mon Sep 17 00:00:00 2001
From: "Charles O. Goddard" 
Date: Mon, 15 Jul 2024 12:59:06 -0700
Subject: [PATCH] Tokenizer merging overhaul (#334)

Rewrite the tokenizer merging logic to support all merge methods and
allow more customization of behavior.

The previous implementation of tokenizer merging always used either
linear or slerp to combine the embedding/LM head parameters. This was
to avoid the complexity that would be required to make all merge
methods support tensors that potentially have invalid or masked-out
values. It worked okay in some cases but was not a general solution.

In this implementation, instead of overriding the merge method for
embed/lm_head, a preprocessing step remaps them to the vocabulary used
by the output model. These (now appropriately sized and ordered)
tensors are then merged normally.

The selection of embedding values for tokens not normally present in a
model is where things get slightly tricky. By default, a set of
heuristics that I think are sane is applied. For a given token and
model, if the token is not present in the model's original tokenizer:

* If the base model has the token, the base model's embedding is used
* If only one model in the merge has the token, that model's embedding
  is used
* Otherwise, the average of all embeddings for the token is used as a
  default value

This can also be overridden at the per-token level. For example:

```yaml
merge_method: dare_ties
base_model: ...
models:
  - model: some_chatml_model
  - model: some_weird_model
  - model: some_model
tokenizer:
  source: union
  tokens:
    # if a model doesn't have <|im_start|>, use the embedding from some_chatml_model
    <|im_start|>:
      source: some_chatml_model
    # use the embedding of <|special|> from some_weird_model for *all* models
    <|special|>:
      source: some_weird_model
      force: true
    # output tokenizer will have <|renamed_token|> with the embedding of
    # <|original_token|> from some_model
    <|renamed_token|>:
      source:
        kind: model_token
        model: some_model
        token: <|original_token|>
      force: true
```

A practical example is merging two Llama 3 models, one using the
Llama 3 Instruct prompt format and one using ChatML, while trying to
preserve the ability to use both formats:

```yaml
tokenizer:
  source: union
  tokens:
    <|im_start|>:
      source: chatml_model
    <|im_end|>:
      source: chatml_model
    <|start_header_id|>:
      source: llama3_model
      force: true
    <|end_header_id|>:
      source: llama3_model
      force: true
    <|eot_id|>:
      source: llama3_model
      force: true
```

(A further illustrative config covering the `zero` source and
`model_token` lookups by id is sketched after the diff.)
---
 mergekit/config.py                            |  14 +-
 mergekit/merge_methods/base.py                |   8 +-
 .../generalized_task_arithmetic.py            |  11 +-
 mergekit/merge_methods/linear.py              |  11 +-
 mergekit/merge_methods/model_stock.py         |  11 +-
 mergekit/merge_methods/passthrough.py         |  11 +-
 mergekit/merge_methods/slerp.py               |  11 +-
 mergekit/merge_methods/tokenizer_permute.py   |  11 +-
 mergekit/plan.py                              |  32 ++-
 mergekit/tokenizer/__init__.py                |  20 ++
 mergekit/{tokenizer.py => tokenizer/build.py} |  57 +++---
 mergekit/tokenizer/config.py                  |  51 +++++
 mergekit/tokenizer/embed.py                   | 182 ++++++++++++++++++
 pyproject.toml                                |   2 +
 tests/test_tokenizer.py                       | 145 ++++++++++++--
 15 files changed, 498 insertions(+), 79 deletions(-)
 create mode 100644 mergekit/tokenizer/__init__.py
 rename mergekit/{tokenizer.py => tokenizer/build.py} (87%)
 create mode 100644 mergekit/tokenizer/config.py
 create mode 100644 mergekit/tokenizer/embed.py

diff --git a/mergekit/config.py b/mergekit/config.py
index 28999a00..9a5e8efd 100644
--- a/mergekit/config.py
+++ b/mergekit/config.py
@@ -17,9 +17,10 @@
 import yaml
 from pydantic 
import BaseModel, model_validator -from typing_extensions import TypeAlias +from typing_extensions import Literal, TypeAlias from mergekit.common import ModelReference +from mergekit.tokenizer.config import TokenizerConfig ScalarOrGradient: TypeAlias = Union[float, List[float]] @@ -88,7 +89,10 @@ class MergeConfiguration(BaseModel): parameters: Optional[Dict[str, ParameterSetting]] = None base_model: Optional[ModelReference] = None dtype: Optional[str] = None - tokenizer_source: Optional[str] = None + tokenizer_source: Union[ + Literal["union"], Literal["base"], ModelReference, None + ] = None + tokenizer: Optional[TokenizerConfig] = None out_dtype: Optional[str] = None def referenced_models(self) -> List[ModelReference]: @@ -110,6 +114,12 @@ def validate_inputs(self): raise RuntimeError("Must specify either output slices or models to merge") return self + @model_validator(mode="after") + def validate_tokenizer(self): + if self.tokenizer_source and self.tokenizer: + raise RuntimeError("Cannot specify both tokenizer_source and tokenizer") + return self + def to_yaml(self) -> str: return yaml.dump( self.model_dump(exclude_defaults=True, mode="json"), diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py index 853fbf31..917ed089 100644 --- a/mergekit/merge_methods/base.py +++ b/mergekit/merge_methods/base.py @@ -14,14 +14,18 @@ # along with this program. If not, see http://www.gnu.org/licenses/. from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from pydantic import BaseModel +from typing_extensions import TypeAlias from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task from mergekit.io.tasks import GatherTensors +from mergekit.tokenizer import PermutedEmbeddings + +MergeTensorInput: TypeAlias = Union[GatherTensors, PermutedEmbeddings] class ConfigParameterDef(BaseModel): @@ -42,7 +46,7 @@ def make_task( self, *, output_weight: WeightInfo, - tensors: GatherTensors, + tensors: MergeTensorInput, parameters: ImmutableMap[str, Any], tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], base_model: Optional[ModelReference], diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 02c1277f..af09c8bb 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -24,8 +24,11 @@ from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) from mergekit.sparsify import SparsificationMethod, sparsify @@ -68,7 +71,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: def make_task( self, output_weight: WeightInfo, - tensors: GatherTensors, + tensors: MergeTensorInput, base_model: Optional[ModelReference], parameters: ImmutableMap[str, Any], tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], @@ -87,7 +90,7 @@ def make_task( class GTATask(Task[torch.Tensor]): method: GeneralizedTaskArithmeticMerge - tensors: GatherTensors + tensors: MergeTensorInput base_model: ModelReference weight_info: WeightInfo tensor_parameters: ImmutableMap[ModelReference, Any] diff --git 
a/mergekit/merge_methods/linear.py b/mergekit/merge_methods/linear.py index 81826a97..48224bb8 100644 --- a/mergekit/merge_methods/linear.py +++ b/mergekit/merge_methods/linear.py @@ -20,13 +20,16 @@ from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) from mergekit.merge_methods.rectify_embed import rectify_embed_sizes class LinearMergeTask(Task[torch.Tensor]): - gather_tensors: GatherTensors + gather_tensors: MergeTensorInput tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] normalize: bool weight_info: WeightInfo @@ -81,7 +84,7 @@ def make_task( self, *, output_weight: WeightInfo, - tensors: GatherTensors, + tensors: MergeTensorInput, parameters: Dict[str, Any], tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], **_kwargs, diff --git a/mergekit/merge_methods/model_stock.py b/mergekit/merge_methods/model_stock.py index 5130f3ea..94b1e05b 100644 --- a/mergekit/merge_methods/model_stock.py +++ b/mergekit/merge_methods/model_stock.py @@ -21,13 +21,16 @@ from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) from mergekit.merge_methods.rectify_embed import rectify_embed_sizes class ModelStockMergeTask(Task[torch.Tensor]): - gather_tensors: GatherTensors + gather_tensors: MergeTensorInput base_model: ModelReference weight_info: WeightInfo filter_wise: bool = False @@ -120,7 +123,7 @@ def make_task( self, *, output_weight: WeightInfo, - tensors: GatherTensors, + tensors: MergeTensorInput, base_model: Optional[ModelReference], parameters: ImmutableMap[str, Any], **_kwargs, diff --git a/mergekit/merge_methods/passthrough.py b/mergekit/merge_methods/passthrough.py index 8e4ba14e..62b0bf12 100644 --- a/mergekit/merge_methods/passthrough.py +++ b/mergekit/merge_methods/passthrough.py @@ -19,12 +19,15 @@ from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) class PassthroughMergeTask(Task[torch.Tensor]): - gather_tensors: GatherTensors + gather_tensors: MergeTensorInput tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] def arguments(self) -> Dict[str, Task]: @@ -52,7 +55,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: def make_task( self, *, - tensors: GatherTensors, + tensors: MergeTensorInput, tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], **kwargs, ) -> Task: diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index dd89d09e..d33dd5a9 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -21,13 +21,16 @@ from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from 
mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) from mergekit.merge_methods.rectify_embed import rectify_embed_sizes class SlerpTask(Task[torch.Tensor]): - gather_tensors: GatherTensors + gather_tensors: MergeTensorInput base_model: ModelReference t: float weight_info: WeightInfo @@ -75,7 +78,7 @@ def make_task( self, *, output_weight: WeightInfo, - tensors: GatherTensors, + tensors: MergeTensorInput, parameters: ImmutableMap[str, Any], base_model: Optional[ModelReference], **_kwargs, diff --git a/mergekit/merge_methods/tokenizer_permute.py b/mergekit/merge_methods/tokenizer_permute.py index 208fb589..07c6f9c5 100644 --- a/mergekit/merge_methods/tokenizer_permute.py +++ b/mergekit/merge_methods/tokenizer_permute.py @@ -20,15 +20,18 @@ from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task -from mergekit.io.tasks import GatherTensors -from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) from mergekit.merge_methods.slerp import slerp from mergekit.tokenizer import BuildTokenizer, TokenizerInfo class TokenizerPermutationMergeTask(Task[torch.Tensor]): tokenizer_task: BuildTokenizer - gather_tensors: GatherTensors + gather_tensors: MergeTensorInput base_model: Optional[ModelReference] use_slerp: bool slerp_t: Optional[float] @@ -134,7 +137,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: def make_task( self, *, - tensors: GatherTensors, + tensors: MergeTensorInput, parameters: Dict[str, Any], tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], base_model: Optional[ModelReference], diff --git a/mergekit/plan.py b/mergekit/plan.py index 7bed6032..bdcd7004 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -40,9 +40,8 @@ TensorWriterTask, ) from mergekit.merge_methods import MergeMethod -from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge from mergekit.options import MergeOptions -from mergekit.tokenizer import BuildTokenizer +from mergekit.tokenizer import BuildTokenizer, PermutedEmbeddings class MergePlanner: @@ -68,12 +67,18 @@ def __init__( self.out_model_config = out_model_config self._method = merge_methods.get(config.merge_method) - if config.tokenizer_source: + token_cfg = {} + tokenizer_source = config.tokenizer_source + if config.tokenizer is not None: + token_cfg = config.tokenizer.tokens or {} + tokenizer_source = config.tokenizer.source + if tokenizer_source is not None: self._tokenizer_task = BuildTokenizer( base_model=config.base_model, referenced_models=tuple(config.referenced_models()), - tokenizer_source=config.tokenizer_source, + tokenizer_source=tokenizer_source, trust_remote_code=options.trust_remote_code, + add_tokens=tuple(token_cfg.keys()), ) @lru_cache @@ -143,11 +148,6 @@ def plan_tensor( return tensor_merge_method = self._method - if self._tokenizer_task and weight.is_embed: - tensor_merge_method = TokenizerPermutationMerge( - tokenizer_task=self._tokenizer_task - ) - cfg_g = cfg_reader.for_tensor(weight.name) global_params = {} for p in tensor_merge_method.parameters(): @@ -176,9 +176,21 @@ def plan_tensor( device="cuda" if self.options.read_to_gpu else None, ) + tensor_input_task = gather_tensors + if self._tokenizer_task and weight.is_embed: + token_cfg = {} + if cfg_reader.config.tokenizer: + token_cfg = 
cfg_reader.config.tokenizer.tokens + tensor_input_task = PermutedEmbeddings( + gather_tensors=gather_tensors, + tokenizer_task=self._tokenizer_task, + tokens=token_cfg, + base_model=base_model, + ) + tensor_task = tensor_merge_method.make_task( output_weight=weight, - tensors=gather_tensors, + tensors=tensor_input_task, parameters=ImmutableMap(data=global_params), tensor_parameters=ImmutableMap( data={ diff --git a/mergekit/tokenizer/__init__.py b/mergekit/tokenizer/__init__.py new file mode 100644 index 00000000..cff42a46 --- /dev/null +++ b/mergekit/tokenizer/__init__.py @@ -0,0 +1,20 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo +from mergekit.tokenizer.config import TokenizerConfig +from mergekit.tokenizer.embed import PermutedEmbeddings + +__all__ = ["BuildTokenizer", "TokenizerInfo", "TokenizerConfig", "PermutedEmbeddings"] diff --git a/mergekit/tokenizer.py b/mergekit/tokenizer/build.py similarity index 87% rename from mergekit/tokenizer.py rename to mergekit/tokenizer/build.py index a3a0f858..fb9f9d9c 100644 --- a/mergekit/tokenizer.py +++ b/mergekit/tokenizer/build.py @@ -16,14 +16,14 @@ import json import logging import tempfile -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import tokenizers import tokenizers.models -import torch import tqdm import transformers from pydantic import BaseModel +from typing_extensions import Literal from mergekit.common import ModelPath, ModelReference from mergekit.graph import Task @@ -169,12 +169,19 @@ def build_union_tokenizer( return res +class TokenizerInfo(BaseModel, arbitrary_types_allowed=True): + tokenizer: transformers.PreTrainedTokenizerBase + permutations: Dict[ModelReference, Dict[int, int]] + original_vocabs: Dict[ModelReference, Dict[str, int]] + + def build_tokenizer( base_model: Optional[ModelReference], referenced_models: List[ModelReference], - tokenizer_source: str, + tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference], trust_remote_code: bool, -) -> Tuple[transformers.PreTrainedTokenizer, Dict[ModelReference, torch.IntTensor]]: + add_tokens: Optional[List[str]] = None, +) -> TokenizerInfo: if base_model is None: base_model = referenced_models[0] if base_model is None: @@ -208,21 +215,25 @@ def build_tokenizer( logging.info("Building output tokenizer") # build final vocabulary - if tokenizer_source == "base": + if isinstance(tokenizer_source, ModelReference): + tokenizer_out = transformers.AutoTokenizer.from_pretrained( + tokenizer_source.model.path, + revision=tokenizer_source.model.revision, + trust_remote_code=trust_remote_code, + ) + elif tokenizer_source == "base": # it done tokenizer_out = tokenizer_base elif tokenizer_source == "union": tokenizer_out = build_union_tokenizer( tokenizer_base, tokenizers, 
trust_remote_code=trust_remote_code ) - elif tokenizer_source.startswith("model:"): - tokenizer_out = transformers.AutoTokenizer.from_pretrained( - tokenizer_source[len("model:") :], - trust_remote_code=trust_remote_code, - ) else: raise RuntimeError(f"Unimplemented tokenizer source: {tokenizer_source}") + for tok in add_tokens: + tokenizer_out.add_tokens(tok) + vocab_out = tokenizer_out.get_vocab() logging.info("Building permutations") @@ -259,28 +270,28 @@ def build_tokenizer( del pbar - return tokenizer_out, permutations - - -class TokenizerInfo(BaseModel, arbitrary_types_allowed=True): - tokenizer: transformers.PreTrainedTokenizerBase - permutations: Optional[Dict[ModelReference, Dict[int, int]]] + return TokenizerInfo( + tokenizer=tokenizer_out, + permutations=permutations, + original_vocabs={model: tok.get_vocab() for model, tok in tokenizers.items()}, + ) class BuildTokenizer(Task[TokenizerInfo]): base_model: Optional[ModelReference] referenced_models: Tuple[ModelReference, ...] - tokenizer_source: str + tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference] + add_tokens: Optional[Tuple[str, ...]] trust_remote_code: bool = False def arguments(self) -> Dict[str, Task]: return {} def execute(self, **_kwargs) -> TokenizerInfo: - tokenizer, permutations = build_tokenizer( - self.base_model, - self.referenced_models, - self.tokenizer_source, - self.trust_remote_code, + return build_tokenizer( + base_model=self.base_model, + referenced_models=self.referenced_models, + tokenizer_source=self.tokenizer_source, + trust_remote_code=self.trust_remote_code, + add_tokens=self.add_tokens, ) - return TokenizerInfo(tokenizer=tokenizer, permutations=permutations) diff --git a/mergekit/tokenizer/config.py b/mergekit/tokenizer/config.py new file mode 100644 index 00000000..94208385 --- /dev/null +++ b/mergekit/tokenizer/config.py @@ -0,0 +1,51 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +from typing import Dict, Optional, Union + +import pydantic +from pydantic import BaseModel +from typing_extensions import Literal + +from mergekit.common import ModelReference + + +class ModelTokenEmbedding(BaseModel, frozen=True): + kind: Literal["model_token"] + model: ModelReference + token_id: Optional[int] = None + token: Optional[str] = None + + @pydantic.model_validator(mode="after") + def validate_token(self): + if self.token_id is None and self.token is None: + raise ValueError("token_id or token must be specified") + if self.token_id is not None and self.token is not None: + raise ValueError("only one of token_id or token may be specified") + return self + + +class ZeroEmbedding(BaseModel, frozen=True): + kind: Literal["zero"] + + +class TokenEmbeddingConfig(BaseModel, frozen=True): + source: Union[ModelTokenEmbedding, ZeroEmbedding, ModelReference, None] = None + force: bool = False + + +class TokenizerConfig(BaseModel, frozen=True): + source: Union[ModelReference, Literal["union"], Literal["base"]] = "union" + tokens: Optional[Dict[str, TokenEmbeddingConfig]] = None diff --git a/mergekit/tokenizer/embed.py b/mergekit/tokenizer/embed.py new file mode 100644 index 00000000..3cdb1840 --- /dev/null +++ b/mergekit/tokenizer/embed.py @@ -0,0 +1,182 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+
+import logging
+from typing import Dict, List, Optional
+
+import torch
+
+from mergekit.common import ImmutableMap, ModelReference
+from mergekit.graph import Task
+from mergekit.io.tasks import GatherTensors
+from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
+from mergekit.tokenizer.config import (
+    ModelTokenEmbedding,
+    TokenEmbeddingConfig,
+    ZeroEmbedding,
+)
+
+
+class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
+    gather_tensors: GatherTensors
+    tokenizer_task: BuildTokenizer
+    tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]
+    base_model: Optional[ModelReference]
+
+    def arguments(self) -> Dict[str, Task]:
+        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}
+
+    def execute(
+        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
+    ) -> Dict[ModelReference, torch.Tensor]:
+        tokenizer = tokenizer_info.tokenizer
+        permutations = tokenizer_info.permutations
+
+        models = set(tensors.keys())
+        if self.base_model:
+            models.add(self.base_model)
+        models = list(models)
+
+        vocab = tokenizer.get_vocab()
+        vocab_size = len(vocab)
+        embed_size = tensors[models[0]].shape[1]
+        assert all(
+            t.shape[1] == embed_size for t in tensors.values()
+        ), "Embedding sizes must match"
+
+        dtype = tensors[models[0]].dtype
+        device = tensors[models[0]].device
+
+        token_configs = dict(**self.tokens) if self.tokens else {}
+        tokens_to_average = self.assign_embedding_sources(
+            permutations, models, vocab, token_configs
+        )
+
+        default_embeds = {}
+        for token, token_id in vocab.items():
+            embed = torch.zeros(embed_size, dtype=dtype, device=device)
+            if token in tokens_to_average:
+                count = 0
+                for model in models:
+                    p = permutations[model]
+                    if p[token_id] < 0:
+                        continue
+                    embed += tensors[model][p[token_id]]
+                    count += 1
+                embed /= count
+            elif cfg := token_configs.get(token, None):
+                cfg: TokenEmbeddingConfig
+                embed = self.compute_default_embedding(
+                    tokenizer_info, tensors, permutations, token, token_id, cfg
+                )
+            else:
+                continue
+            default_embeds[token] = embed
+
+        result = {}
+        for model in models:
+            p = permutations[model]
+            old_embed = tensors[model]
+            new_embed = torch.zeros(
+                (vocab_size, embed_size), dtype=dtype, device=device
+            )
+            for token, token_id in vocab.items():
+                force = False
+                if token in token_configs:
+                    force = token_configs[token].force
+
+                if p[token_id] >= 0 and not force:
+                    new_embed[token_id, :] = old_embed[p[token_id]]
+                elif token in default_embeds:
+                    new_embed[token_id, :] = default_embeds[token]
+                else:
+                    logging.error(
+                        f"No embedding for token {repr(token)} in model {model}!"
+                    )
+            result[model] = new_embed
+
+        return result
+
+    def assign_embedding_sources(
+        self,
+        permutations: Dict[ModelReference, Dict[int, int]],
+        models: List[ModelReference],
+        vocab: Dict[str, int],
+        token_configs: Dict[str, TokenEmbeddingConfig],
+    ):
+        permutation_list = [permutations[model] for model in models]
+
+        tokens_to_average = set()
+        # find tokens that are only present in one model
+        for token, token_id in vocab.items():
+            if token in token_configs:
+                continue
+
+            has_token = [p[token_id] >= 0 for p in permutation_list]
+            num_present = sum(int(x) for x in has_token)
+            if num_present == 1:
+                donor_model = models[has_token.index(True)]
+                token_configs[token] = TokenEmbeddingConfig(source=donor_model)
+                continue
+
+            if num_present == 0:
+                token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
+                logging.warning(f"Token {repr(token)} not found in any model")
+                continue
+
+            if num_present > 0 and self.base_model is not None:
+                if permutations[self.base_model][token_id] >= 0:
+                    token_configs[token] = TokenEmbeddingConfig(source=self.base_model)
+                    continue
+
+            tokens_to_average.add(token)
+        return tokens_to_average
+
+    def compute_default_embedding(
+        self,
+        tokenizer_info: TokenizerInfo,
+        tensors: Dict[ModelReference, torch.Tensor],
+        permutations: Dict[ModelReference, Dict[int, int]],
+        token: str,
+        token_id: int,
+        cfg: TokenEmbeddingConfig,
+    ) -> torch.Tensor:
+        if isinstance(cfg.source, ZeroEmbedding):
+            embed = torch.zeros_like(next(iter(tensors.values()))[0])  # zero vector
+        elif isinstance(cfg.source, ModelTokenEmbedding):
+            model = cfg.source.model
+            assert (
+                model in permutations
+            ), f"Model {model} referenced but not part of merge"
+            p = permutations[model]
+            src_token_id = cfg.source.token_id
+            if src_token_id is None:
+                src_token = cfg.source.token
+                assert (
+                    src_token in tokenizer_info.original_vocabs[model]
+                ), f"Token {repr(src_token)} not found in model {model}"
+                src_token_id = tokenizer_info.original_vocabs[model][src_token]
+            assert (
+                src_token_id >= 0 and src_token_id < tensors[model].shape[0]
+            ), f"Token ID {src_token_id} out of range for model {model}"
+            embed = tensors[model][src_token_id]
+        elif isinstance(cfg.source, ModelReference):
+            model = cfg.source
+            p = permutations[model]
+            assert p[token_id] >= 0, f"Token {repr(token)} not found in model {model}"
+            embed = tensors[model][p[token_id]]
+        else:
+            raise NotImplementedError(cfg)
+        return embed
diff --git a/pyproject.toml b/pyproject.toml
index 7cf524a8..b612908e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,5 +72,7 @@ include = '\.pyi?$'
 minversion = "6.0"
 filterwarnings = [
     "ignore::pydantic.PydanticDeprecatedSince20:huggingface_hub.*:",
+    "ignore::FutureWarning:huggingface_hub.*:",
+    "ignore:(read_text|open_text|contents) is deprecated:DeprecationWarning", # yes i know, but files() doesn't exist in 3.8
 ]
 testpaths = ["tests"]
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index 93b33925..17fafcc8 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -1,14 +1,17 @@
 import json
 import os
 import tempfile
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import pytest
 import tokenizers
+import torch
 from common import make_picollama, run_and_check_merge
 from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase
 
-from mergekit.config import InputModelDefinition, MergeConfiguration, ParameterSetting
+from mergekit.config import InputModelDefinition, MergeConfiguration
+from mergekit.io import LazyTensorLoader
+from mergekit.tokenizer import TokenizerConfig
 
 
 @pytest.fixture(scope="session")
@@ -87,6 +90,23 @@ def _cb(model_path: str):
     return _cb
 
 
+class ModelEmbeddings:
+    embed_tokens: torch.Tensor
+    vocab: Dict[str, int]
+
+    def __init__(self, model_path: str):
+        tokenizer = LlamaTokenizerFast.from_pretrained(model_path)
+        loader = LazyTensorLoader.from_disk(model_path)
+        self.embed_tokens = loader.get_tensor("model.embed_tokens.weight")
+        self.vocab = tokenizer.get_vocab()
+
+    def token_embedding(self, token: str) -> Optional[torch.Tensor]:
+        idx = self.vocab.get(token)
+        if idx is None:
+            return None
+        return self.embed_tokens[idx, :]
+
+
 class TestTokenizerMerges:
     def test_legacy_mode(self, model_base: str, model_padded: str, model_chatml: str):
         config = self.make_config(
@@ -115,23 +135,39 @@ def test_source_union(self, model_base: str, model_padded: str, model_chatml: st
             tokenizer_source="union",
         )
 
-        # output should have all tokens used by any model
-        # but not include any unused tokens
-        run_and_check_merge(
-            config,
-            validate=check_tokenizer(
+        def _check_embed(model_path: str):
+            # output should have all tokens used by any model
+            # but not include any unused tokens
+            check_tokenizer(
                 expected_size=66,
                 expected_added_ct=5,
                 must_contain=["<|im_start|>", "<|im_end|>"],
                 must_not_contain=[f"<pad_{idx}>" for idx in range(4)],
-            ),
+            )(model_path)
+            emb_out = ModelEmbeddings(model_path)
+            emb_chatml = ModelEmbeddings(model_chatml)
+
+            assert torch.allclose(
+                emb_out.token_embedding("<|im_start|>"),
+                emb_chatml.token_embedding("<|im_start|>"),
+            ), "Token <|im_start|> should be from model_chatml"
+            assert torch.allclose(
+                emb_out.token_embedding("<|im_end|>"),
+                emb_chatml.token_embedding("<|im_end|>"),
+                atol=1e-3,
+                rtol=1e-4,
+            ), "Token <|im_end|> should be from model_chatml"
+
+        run_and_check_merge(
+            config,
+            validate=_check_embed,
         )
 
     def test_source_model(self, model_base: str, model_padded: str, model_chatml: str):
         config = self.make_config(
             [model_base, model_padded, model_chatml],
             base_model=model_base,
-            tokenizer_source="model:" + model_chatml,
+            tokenizer_source=model_chatml,
         )
         # tokenizer should match model_chatml
         run_and_check_merge(
@@ -147,8 +183,7 @@ def test_slerp_union(self, model_base: str, model_chatml: str):
             base_model=model_base,
             tokenizer_source="union",
             merge_method="slerp",
-            embed_slerp=True,
-            t="0.5",
+            t=0.5,
         )
 
         run_and_check_merge(
@@ -159,19 +194,92 @@
             config,
             validate=check_tokenizer(
                 expected_size=66, must_contain=["<|im_start|>", "<|im_end|>"]
            ),
         )
 
+    def test_force_token(self, model_base: str, model_chatml: str):
+        config = self.make_config(
+            [model_base, model_chatml],
+            base_model=model_base,
+            merge_method="linear",
+            tokenizer_config=TokenizerConfig(
+                source="union",
+                tokens={
+                    "_tok_10": {"source": model_chatml, "force": True},
+                    "_tok_11": {"source": model_base, "force": True},
+                },
+            ),
+        )
+
+        def _check_embed(model_path: str):
+            check_tokenizer(
+                expected_size=66, must_contain=["<|im_start|>", "<|im_end|>"]
+            )(model_path)
+            emb_out = ModelEmbeddings(model_path)
+            emb_base = ModelEmbeddings(model_base)
+            emb_chatml = ModelEmbeddings(model_chatml)
+
+            assert torch.allclose(
+                emb_out.token_embedding("_tok_10"),
+                emb_chatml.token_embedding("_tok_10"),
+            ), "Token _tok_10 should be from model_chatml"
+            assert torch.allclose(
+                emb_out.token_embedding("_tok_11"),
+                emb_base.token_embedding("_tok_11"),
+            ), "Token _tok_11 should be from model_base"
+
+        run_and_check_merge(config, validate=_check_embed)
+
+    def test_model_token_id(self, model_base: str, model_chatml: str):
+        config = self.make_config(
+            [model_base, model_chatml],
+            base_model=model_base,
+            merge_method="linear",
+            tokenizer_config=TokenizerConfig(
+                source="base",
+                tokens={
+                    "_tok_20": {
+                        "source": {
+                            "kind": "model_token",
+                            "model": model_chatml,
+                            "token_id": 64,
+                        },
+                        "force": True,
+                    },
+                    "_tok_21": {
+                        "source": {
+                            "kind": "model_token",
+                            "model": model_base,
+                            "token": "<s>",
+                        },
+                        "force": True,
+                    },
+                },
+            ),
+        )
+
+        def _check_embed(model_path: str):
+            check_tokenizer(expected_size=64, must_contain=["_tok_10"])(model_path)
+            emb_out = ModelEmbeddings(model_path)
+            emb_base = ModelEmbeddings(model_base)
+            emb_chatml = ModelEmbeddings(model_chatml)
+
+            assert torch.allclose(
+                emb_out.token_embedding("_tok_20"), emb_chatml.embed_tokens[64, :]
+            ), "Token _tok_20 should be == model_chatml token 64"
+            assert torch.allclose(
+                emb_out.token_embedding("_tok_21"), emb_base.token_embedding("<s>")
+            ), "Token _tok_21 should be == model_base <s>"
+
+        run_and_check_merge(config, validate=_check_embed)
+
     def make_config(
         self,
         models: List[str],
         base_model: Optional[str] = None,
         merge_method: str = "linear",
         tokenizer_source: Optional[str] = None,
-        embed_slerp: bool = False,
-        t: Optional[ParameterSetting] = None,
+        t: Optional[float] = None,
+        tokenizer_config: Optional[TokenizerConfig] = None,
     ):
-        parameters = {"embed_slerp": embed_slerp}
-        if t is not None:
-            parameters["t"] = t
-
+        parameters = {"t": t} if t is not None else {}
         config = MergeConfiguration(
             merge_method=merge_method,
             base_model=base_model,
@@ -182,8 +290,9 @@ def make_config(
                 )
                 for m in models
             ],
-            dtype="bfloat16",
+            dtype="float32",
             tokenizer_source=tokenizer_source,
             parameters=parameters,
+            tokenizer=tokenizer_config,
         )
         return config
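
---

For reference, the `TokenizerConfig` schema added above also supports a
`zero` source (the token's embedding is left as all zeros) and
`model_token` lookups by numeric `token_id` rather than token text, as
exercised in `test_model_token_id`. A minimal illustrative config is
sketched below; the model name (`some_model`) and token names are
placeholders, not part of this patch:

```yaml
tokenizer:
  source: union
  tokens:
    # force this token's embedding to all zeros in every input model
    <|pad|>:
      source:
        kind: zero
      force: true
    # copy the embedding at a specific row of some_model's embedding matrix
    <|special|>:
      source:
        kind: model_token
        model: some_model
        token_id: 42
      force: true
```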