From b184ae21b4ab86b7e927c565f26f05c04c932c10 Mon Sep 17 00:00:00 2001
From: Charles Goddard
Date: Sat, 30 Dec 2023 18:15:39 -0800
Subject: [PATCH] Add merge tests and fix various problems

---
 mergekit/config.py                          |  29 ++---
 mergekit/merge_methods/base.py              |   8 +-
 .../generalized_task_arithmetic.py          |  12 +-
 mergekit/merge_methods/linear.py            |   6 +-
 mergekit/merge_methods/passthrough.py       |   4 +-
 mergekit/merge_methods/slerp.py             |  16 +--
 mergekit/merge_methods/tokenizer_permute.py |   4 +-
 mergekit/plan.py                            |  81 +++++++++++--
 mergekit/tasks.py                           |   2 +
 tests/test_merges.py                        | 108 ++++++++++++++++++
 10 files changed, 218 insertions(+), 52 deletions(-)
 create mode 100644 tests/test_merges.py

diff --git a/mergekit/config.py b/mergekit/config.py
index 24b2e3f3..995bf4aa 100644
--- a/mergekit/config.py
+++ b/mergekit/config.py
@@ -116,7 +116,6 @@ class ConfigReader(BaseModel):
     t: float
     tensor_name: Optional[str] = None
     slice_out: Optional[OutputSliceDefinition] = None
-    slices_in: Optional[List[InputSliceDefinition]] = None
 
     @property
     def base_model(self) -> Optional[ModelReference]:
@@ -135,16 +134,6 @@ def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
         return ConfigReader(
             config=self.config,
             t=self.t,
             tensor_name=self.tensor_name,
             slice_out=slice,
-            slices_in=self.slices_in,
-        )
-
-    def for_in_slices(self, slices: List[InputSliceDefinition]) -> "ConfigReader":
-        return ConfigReader(
-            config=self.config,
-            t=self.t,
-            tensor_name=self.tensor_name,
-            slice_out=self.slice_out,
-            slices_in=slices,
         )
 
     def for_tensor(self, tensor_name: str) -> "ConfigReader":
@@ -153,7 +142,6 @@ def for_tensor(self, tensor_name: str) -> "ConfigReader":
         return ConfigReader(
             config=self.config,
             t=self.t,
             tensor_name=tensor_name,
             slice_out=self.slice_out,
-            slices_in=self.slices_in,
         )
 
     def with_t(self, t: float) -> "ConfigReader":
@@ -162,7 +150,6 @@ def with_t(self, t: float) -> "ConfigReader":
         return ConfigReader(
             config=self.config,
             t=t,
             tensor_name=self.tensor_name,
             slice_out=self.slice_out,
-            slices_in=self.slices_in,
         )
 
     def parameter(
@@ -172,16 +159,16 @@ def parameter(
         model: Optional[ModelReference] = None,
         default: Any = None,
         required: bool = False,
     ) -> Any:
-        if model and self.slices_in:
-            for s in self.slices_in:
-                if s.model == str(model) and s.parameters and name in s.parameters:
-                    value = evaluate_setting(
-                        self.tensor_name, s.parameters[name], self.t
-                    )
-                    if value is not None:
-                        return value
         if self.slice_out:
+            if model:
+                for s in self.slice_out.sources:
+                    if s.model == str(model) and s.parameters and name in s.parameters:
+                        value = evaluate_setting(
+                            self.tensor_name, s.parameters[name], self.t
+                        )
+                        if value is not None:
+                            return value
             if self.slice_out.parameters and name in self.slice_out.parameters:
                 value = evaluate_setting(
                     self.tensor_name, self.slice_out.parameters[name], self.t
diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py
index 12d7a51a..dee79731 100644
--- a/mergekit/merge_methods/base.py
+++ b/mergekit/merge_methods/base.py
@@ -14,11 +14,11 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 
 from pydantic import BaseModel
 
-from mergekit.common import ModelReference
+from mergekit.common import ImmutableMap, ModelReference
 from mergekit.graph import Task
 from mergekit.tasks import GatherTensors
 
@@ -42,8 +42,8 @@ def make_task(
         *,
         output_tensor_name: str,
         tensors: GatherTensors,
-        parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        parameters: ImmutableMap[str, Any],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         base_model: Optional[ModelReference],
     ) -> Task:
         ...
diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py
index 4ca27f6d..ea1fb42f 100644
--- a/mergekit/merge_methods/generalized_task_arithmetic.py
+++ b/mergekit/merge_methods/generalized_task_arithmetic.py
@@ -33,7 +33,7 @@ class ConsensusMethod(str, Enum):
     sum = "sum"
 
 
-class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel):
+class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
     consensus_method: Optional[ConsensusMethod]
     sparsification_method: Optional[SparsificationMethod]
     default_normalize: bool
@@ -57,14 +57,14 @@ def make_task(
         output_tensor_name: str,
         tensors: GatherTensors,
         base_model: Optional[ModelReference],
-        parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        parameters: ImmutableMap[str, Any],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
     ) -> Task:
         return GTATask(
             method=self,
             tensors=tensors,
             base_model=base_model,
-            tensor_parameters=ImmutableMap(data=tensor_parameters),
+            tensor_parameters=tensor_parameters,
             int8_mask=parameters["int8_mask"],
             normalize=parameters["normalize"],
             out_tensor_name=output_tensor_name,
@@ -141,8 +141,8 @@ def execute(
 def get_task_vectors(
     parameter_name: str,
     base_model: ModelReference,
-    tensors: Dict[ModelReference, torch.Tensor],
-    tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+    tensors: ImmutableMap[ModelReference, torch.Tensor],
+    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
 ) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
     keys = list(tensors.keys())
     base = tensors[base_model]
diff --git a/mergekit/merge_methods/linear.py b/mergekit/merge_methods/linear.py
index 1689ebd7..96a2e776 100644
--- a/mergekit/merge_methods/linear.py
+++ b/mergekit/merge_methods/linear.py
@@ -26,7 +26,7 @@ class LinearMergeTask(Task[torch.Tensor]):
     gather_tensors: GatherTensors
-    tensor_parameters: ImmutableMap[ModelReference, Dict[str, Any]]
+    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
     normalize: bool
     parameter_name: str
 
@@ -74,12 +74,12 @@ def make_task(
         output_tensor_name: str,
         tensors: GatherTensors,
         parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         **_kwargs,
     ) -> Task:
         return LinearMergeTask(
             gather_tensors=tensors,
-            tensor_parameters=ImmutableMap(data=tensor_parameters),
+            tensor_parameters=tensor_parameters,
             normalize=parameters["normalize"],
             parameter_name=output_tensor_name,
         )
diff --git a/mergekit/merge_methods/passthrough.py b/mergekit/merge_methods/passthrough.py
index 41208e1e..1165e484 100644
--- a/mergekit/merge_methods/passthrough.py
+++ b/mergekit/merge_methods/passthrough.py
@@ -18,7 +18,7 @@
 import torch
 from torch._tensor import Tensor
 
-from mergekit.common import ModelReference
+from mergekit.common import ImmutableMap, ModelReference
 from mergekit.graph import Task
 from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
 from mergekit.tasks import GatherTensors
@@ -50,7 +50,7 @@ def make_task(
         self,
         *,
         tensors: GatherTensors,
-        parameters: Dict[str, Any],
+        parameters: ImmutableMap[str, Any],
         **kwargs,
     ) -> Task:
         return PassthroughMergeTask(gather_tensors=tensors, scale=parameters["scale"])
diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py
index c6fbe8bf..907eed89 100644
--- a/mergekit/merge_methods/slerp.py
+++ b/mergekit/merge_methods/slerp.py
@@ -19,7 +19,7 @@
 import torch
 from torch._tensor import Tensor
 
-from mergekit.common import ModelReference, rectify_embed_sizes
+from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
 from mergekit.graph import Task
 from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
 from mergekit.tasks import GatherTensors
@@ -34,13 +34,15 @@ class SlerpTask(Task[torch.Tensor]):
     def arguments(self) -> Dict[str, Task]:
         return {"tensors": self.gather_tensors}
 
-    def execute(self, input_tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
-        if len(input_tensors) == 1:
-            return list(input_tensors.values())[0]
-        elif len(input_tensors) != 2:
+    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
+        if len(tensors) == 1:
+            return list(tensors.values())[0]
+        elif len(tensors) != 2:
             raise RuntimeError("Slerp merge expects exactly two models")
+        elif self.base_model not in tensors:
+            raise RuntimeError("Base model not in input tensors")
 
-        [a, b] = list(input_tensors.items())
+        [a, b] = list(tensors.items())
         if a[0] != self.base_model:
             [a, b] = [b, a]
         prepped_tensors = [a[1], b[1]]
@@ -67,7 +69,7 @@ def make_task(
         *,
         output_tensor_name: str,
         tensors: GatherTensors,
-        parameters: Dict[str, Any],
+        parameters: ImmutableMap[str, Any],
         base_model: ModelReference | None,
         **_kwargs,
     ) -> Task:
diff --git a/mergekit/merge_methods/tokenizer_permute.py b/mergekit/merge_methods/tokenizer_permute.py
index c5bca95a..d4b1ffc7 100644
--- a/mergekit/merge_methods/tokenizer_permute.py
+++ b/mergekit/merge_methods/tokenizer_permute.py
@@ -126,7 +126,7 @@ def make_task(
         *,
         tensors: GatherTensors,
         parameters: Dict[str, Any],
-        tensor_parameters: Dict[ModelReference, Dict[str, Any]],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         base_model: Optional[ModelReference],
         **_kwargs,
     ) -> Task:
@@ -136,5 +136,5 @@ def make_task(
             gather_tensors=tensors,
             use_slerp=parameters["embed_slerp"],
             slerp_t=parameters["t"],
-            tensor_parameters=ImmutableMap(data=tensor_parameters),
+            tensor_parameters=tensor_parameters,
         )
diff --git a/mergekit/plan.py b/mergekit/plan.py
index 865d8722..8a755c6e 100644
--- a/mergekit/plan.py
+++ b/mergekit/plan.py
@@ -13,6 +13,7 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
+import logging
 from typing import List, Optional
 
 from mergekit import merge_methods
@@ -35,6 +36,7 @@ class MergePlanner:
     config: MergeConfiguration
     arch_info: ArchitectureInfo
     clone_tensors: bool
+    trust_remote_code: bool
     _writer_task: TensorWriterTask
     _method: MergeMethod
     _tasks: List[Task] = []
@@ -51,6 +53,7 @@ def __init__(
         self.config = config
         self.arch_info = arch_info
         self.clone_tensors = options.clone_tensors
+        self.trust_remote_code = options.trust_remote_code
         self._method = merge_methods.get(config.merge_method)
         self._writer_task = TensorWriterTask(
             out_path=out_path, max_shard_size=options.out_shard_size
         )
@@ -68,6 +71,53 @@ def __init__(
             trust_remote_code=options.trust_remote_code,
         )
 
+    def normalize_config(self):
+        base_model = (
+            ModelReference.parse(self.config.base_model)
+            if self.config.base_model
+            else None
+        )
+
+        # if models to merge are specified instead of output slices, compute them
+        if self.config.models:
+            if self.config.slices:
+                raise RuntimeError(
+                    "Must specify either models to merge or output slices"
+                )
+
+            slices_in = []
+            base_included = False
+
+            for model_in in self.config.models:
+                mref = ModelReference.parse(model_in.model)
+
+                if base_model and mref == base_model:
+                    base_included = True
+
+                model_cfg = mref.config(trust_remote_code=self.trust_remote_code)
+                num_layers = self.arch_info.num_layers(model_cfg)
+                slices_in.append(
+                    InputSliceDefinition(
+                        layer_range=[0, num_layers],
+                        model=model_in.model,
+                        parameters=model_in.parameters,
+                    )
+                )
+
+            if base_model and not base_included:
+                logging.info("Base model specified but not in input models - adding")
+                base_cfg = base_model.config(trust_remote_code=self.trust_remote_code)
+                num_layers = self.arch_info.num_layers(base_cfg)
+                slices_in.append(
+                    InputSliceDefinition(
+                        layer_range=[0, num_layers],
+                        model=str(base_model),
+                    )
+                )
+
+            self.config.slices = [OutputSliceDefinition(sources=slices_in)]
+            self.config.models = None
+
     def plan_tensor(
         self,
         name: str,
@@ -82,7 +132,7 @@ def plan_tensor(
             tokenizer_task=self._tokenizer_task
         )
 
-        cfg_g = cfg_reader.for_in_slices(None).for_tensor(name)
+        cfg_g = cfg_reader.for_tensor(name)
         global_params = {}
         for p in tensor_merge_method.parameters():
             global_params[p.name] = cfg_g.parameter(
@@ -91,11 +141,15 @@ def plan_tensor(
 
         tensor_params = {}
         for model, name_in in zip(models, names_in):
+            is_base = str(model) == cfg_reader.config.base_model
             tensor_params[model] = {}
             cfg_m = cfg_reader.for_tensor(name_in)
             for p in tensor_merge_method.tensor_parameters():
                 tensor_params[model][p.name] = cfg_m.parameter(
-                    p.name, model=model, required=p.required, default=p.default_value
+                    p.name,
+                    model=model,
+                    required=p.required and not is_base,
+                    default=p.default_value,
                 )
 
         gather_tensors = GatherTensors(
@@ -111,8 +165,12 @@ def plan_tensor(
         tensor_task = tensor_merge_method.make_task(
             output_tensor_name=name,
             tensors=gather_tensors,
-            parameters=global_params,
-            tensor_parameters=tensor_params,
+            parameters=ImmutableMap(data=global_params),
+            tensor_parameters=ImmutableMap(
+                data={
+                    key: ImmutableMap(data=tensor_params[key]) for key in tensor_params
+                }
+            ),
             base_model=base_model,
         )
         save_task = SaveTensor(
@@ -165,10 +223,11 @@ def plan_slice(self, definition: OutputSliceDefinition):
                 definition.sources,
                 layer_offset=idx,
                 t=t,
-                cfg_reader=cfg_reader.for_in_slices(definition.sources),
+                cfg_reader=cfg_reader,
             )
 
     def plan(self):
+        self.normalize_config()
         self._tasks = []
 
         for weight_name in self.arch_info.pre_weights():
@@ -176,7 +235,11 @@ def plan(self):
                 weight_name,
                 [weight_name] * len(self.config.slices[0].sources),
                 [ModelReference.parse(s.model) for s in self.config.slices[0].sources],
-                ConfigReader(config=self.config, t=0, tensor_name=weight_name),
+                ConfigReader(
+                    config=self.config,
+                    t=0,
+                    tensor_name=weight_name,
+                ).for_out_slice(self.config.slices[0]),
             )
 
         for out_slice in self.config.slices:
@@ -187,7 +250,11 @@ def plan(self):
                 weight_name,
                 [weight_name] * len(self.config.slices[-1].sources),
                 [ModelReference.parse(s.model) for s in self.config.slices[-1].sources],
-                ConfigReader(config=self.config, t=1, tensor_name=weight_name),
+                ConfigReader(
+                    config=self.config,
+                    t=1,
+                    tensor_name=weight_name,
+                ).for_out_slice(self.config.slices[-1]),
             )
 
         self._tasks.append(
diff --git a/mergekit/tasks.py b/mergekit/tasks.py
index 02e4a374..b4a71866 100644
--- a/mergekit/tasks.py
+++ b/mergekit/tasks.py
@@ -64,6 +64,8 @@ def execute(self) -> torch.Tensor:
 
     def group_label(self) -> Optional[str]:
         loader = LoaderCache().get(self.model)
+        if self.tensor not in loader.index.tensor_paths:
+            raise KeyError(f"Tensor {self.tensor} not present in {self.model}")
         shard_path = loader.index.tensor_paths[self.tensor]
         return _normalized_shard_name(shard_path)
 
diff --git a/tests/test_merges.py b/tests/test_merges.py
new file mode 100644
index 00000000..cf2fae5b
--- /dev/null
+++ b/tests/test_merges.py
@@ -0,0 +1,108 @@
+import tempfile
+from typing import Optional
+
+import pytest
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from mergekit.common import MergeOptions
+from mergekit.config import (
+    InputModelDefinition,
+    InputSliceDefinition,
+    MergeConfiguration,
+    OutputSliceDefinition,
+)
+from mergekit.merge import run_merge
+
+
+def make_picollama(path: str):
+    cfg = LlamaConfig(
+        vocab_size=64,
+        hidden_size=128,
+        intermediate_size=128,
+        num_attention_heads=16,
+        num_hidden_layers=1,
+    )
+    model = LlamaForCausalLM(cfg)
+    model.save_pretrained(path, safe_serialization=True)
+    return str(path)
+
+
+@pytest.fixture(scope="session")
+def model_a(tmp_path_factory):
+    return make_picollama(tmp_path_factory.mktemp("model_a"))
+
+
+@pytest.fixture(scope="session")
+def model_b(tmp_path_factory):
+    return make_picollama(tmp_path_factory.mktemp("model_b"))
+
+
+@pytest.fixture(scope="session")
+def model_c(tmp_path_factory):
+    return make_picollama(tmp_path_factory.mktemp("model_c"))
+
+
+class TestMerges:
+    def test_gpt2_copy(self):
+        config = MergeConfiguration(
+            merge_method="passthrough",
+            models=[InputModelDefinition(model="gpt2")],
+            dtype="bfloat16",
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_merge(config, out_path=tmpdir, options=MergeOptions())
+
+    def test_gpt2_stack(self):
+        config = MergeConfiguration(
+            merge_method="passthrough",
+            slices=[
+                OutputSliceDefinition(
+                    sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
+                    * 2
+                )
+            ],
+            dtype="bfloat16",
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_merge(config, out_path=tmpdir, options=MergeOptions())
+
+    def test_linear_merge(self, model_a, model_b):
+        config = self.two_model_config(model_a, model_b, merge_method="linear")
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_merge(config, out_path=tmpdir, options=MergeOptions())
+
+    def test_slerp_merge(self, model_a, model_b):
+        config = self.two_model_config(
+            model_a, model_b, merge_method="slerp", base_model=model_a
+        )
+        config.parameters = {"t": 0.35}
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_merge(config, out_path=tmpdir, options=MergeOptions())
+
+    def test_task_arithmetic_merge(self, model_a, model_b, model_c):
+        config = self.two_model_config(
+            model_a, model_b, merge_method="task_arithmetic", base_model=model_c
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_merge(config, out_path=tmpdir, options=MergeOptions())
+
+    def two_model_config(
+        self, model_a, model_b, merge_method: str, base_model: Optional[str] = None
+    ):
+        config = MergeConfiguration(
+            merge_method=merge_method,
+            base_model=base_model,
+            models=[
+                InputModelDefinition(
+                    model=model_a,
+                    parameters={"weight": 0.6},
+                ),
+                InputModelDefinition(
+                    model=model_b,
+                    parameters={"weight": 0.4},
+                ),
+            ],
+            dtype="bfloat16",
+        )
+
+        return config
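For reference, a minimal sketch (not part of the patch) of driving the merge pipeline the same way the new tests do. It uses only names that appear in the diff above (MergeConfiguration, InputModelDefinition, MergeOptions, run_merge); the model paths are placeholders.

    import tempfile

    from mergekit.common import MergeOptions
    from mergekit.config import InputModelDefinition, MergeConfiguration
    from mergekit.merge import run_merge

    # Weighted linear merge of two models, mirroring TestMerges.test_linear_merge.
    # With only `models` given, MergePlanner.normalize_config() expands the list
    # into a single output slice covering every layer of each model before the
    # per-tensor merge tasks are planned.
    config = MergeConfiguration(
        merge_method="linear",
        models=[
            InputModelDefinition(model="path/to/model_a", parameters={"weight": 0.6}),
            InputModelDefinition(model="path/to/model_b", parameters={"weight": 0.4}),
        ],
        dtype="bfloat16",
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        run_merge(config, out_path=tmpdir, options=MergeOptions())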