Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move argument parsing to click #65

Merged
merged 3 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mergekit/io/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Sequence
from typing import Dict, Optional, Sequence

import safetensors
import torch
Expand Down
2 changes: 1 addition & 1 deletion mergekit/io/tensor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def finalize(self):
) as file:
json.dump(
{
"metadata": {"mergekit_version": "0.0.3.1"},
"metadata": {"mergekit_version": "0.0.3.2"},
"weight_map": self.weight_map,
},
file,
Expand Down
1 change: 0 additions & 1 deletion mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class MergeOptions(BaseModel):
low_cpu_memory: bool = False
out_shard_size: int = parse_kmb("5B")
copy_tokenizer: bool = True
allow_crimes: bool = False
clone_tensors: bool = False
trust_remote_code: bool = False
random_seed: Optional[int] = None
Expand Down
86 changes: 86 additions & 0 deletions mergekit/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import typing
from typing import Any, Callable, Optional, Union

import click
from click.core import Context, Parameter

from mergekit.common import parse_kmb
from mergekit.merge import MergeOptions

OPTION_HELP = {
"allow_crimes": "Allow mixing architectures",
"transformers_cache": "Override storage path for downloaded models",
"lora_merge_cache": "Path to store merged LORA models",
"cuda": "Perform matrix arithmetic on GPU",
"low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM",
"out_shard_size": "Number of parameters per output shard [default: 5B]",
"copy_tokenizer": "Copy a tokenizer to the output",
"clone_tensors": "Clone tensors before saving, to allow multiple occurrences of the same layer",
"trust_remote_code": "Trust remote code from huggingface repos (danger)",
"random_seed": "Seed for reproducible use of randomized merge methods",
"lazy_unpickle": "Experimental lazy unpickler for lower memory usage",
}


class ShardSizeParamType(click.ParamType):
name = "size"

def convert(
self, value: Any, param: Optional[Parameter], ctx: Optional[Context]
) -> int:
return parse_kmb(value)


def add_merge_options(f: Callable) -> Callable:
def wrapper(*args, **kwargs):
arg_dict = {}
for field_name in MergeOptions.model_fields:
if field_name in kwargs:
arg_dict[field_name] = kwargs.pop(field_name)

kwargs["merge_options"] = MergeOptions(**arg_dict)
f(*args, **kwargs)

for field_name, info in reversed(MergeOptions.model_fields.items()):
origin = typing.get_origin(info.annotation)
if origin is Union:
ty, prob_none = typing.get_args(info.annotation)
assert prob_none is type(None)
field_type = ty
else:
field_type = info.annotation

if field_name == "out_shard_size":
field_type = ShardSizeParamType()

arg_name = field_name.replace("_", "-")
if field_type == bool:
arg_str = f"--{arg_name}/--no-{arg_name}"
else:
arg_str = f"--{arg_name}"

help_str = OPTION_HELP.get(field_name, None)
wrapper = click.option(
arg_str,
type=field_type,
default=info.default,
help=help_str,
show_default=field_name != "out_shard_size",
)(wrapper)

return wrapper
29 changes: 15 additions & 14 deletions mergekit/scripts/bakllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

from typing import List, Optional

import typer
import click
import yaml
from pydantic import BaseModel
from typing_extensions import Annotated

from mergekit.config import (
ConditionalParameter,
Expand All @@ -41,16 +40,22 @@ class BakllamaConfig(BaseModel):
lm_head_source: Optional[str] = None


@click.command("bakllama")
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=str)
@click.option(
"--clone-tensors/--no-clone-tensors",
type=bool,
is_flag=True,
help="Clone tensors before saving, to allow multiple occurrences of the same layer",
default=False,
)
@click.option("--fp16/--no-fp16", type=bool, default=False)
def main(
config_path: str,
out_path: str,
clone_tensors: Annotated[
bool,
typer.Option(
help="Clone tensors before saving, to allow multiple occurrences of the same layer"
),
] = False,
fp16: bool = False,
clone_tensors: bool,
fp16: bool,
):
"""Wrapper for using legacy bakllama configuration files."""
with open(config_path, "r", encoding="utf-8") as file:
Expand All @@ -75,9 +80,5 @@ def main(
run_merge(merge_config, out_path, MergeOptions(clone_tensors=clone_tensors))


def _main():
typer.run(main)


if __name__ == "__main__":
_main()
main()
103 changes: 48 additions & 55 deletions mergekit/scripts/layershuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

import random
from typing import List, Optional
from typing import List

import typer
import click
import yaml
from typing_extensions import Annotated

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference
Expand All @@ -28,49 +27,51 @@
OutputSliceDefinition,
)
from mergekit.merge import MergeOptions, run_merge


from mergekit.options import add_merge_options


@click.command("mergekit-layershuffle")
@click.argument("out_path", type=str)
@click.option("--model", "-m", multiple=True, type=str, help="Add a model to the merge")
@click.option(
"--weight",
"-w",
multiple=True,
type=float,
default=[],
show_default=False,
help="Weighting for a model",
)
@click.option(
"--print-yaml/--no-print-yaml",
is_flag=True,
help="Print YAML merge config for resulting model",
)
@click.option(
"--write-yaml",
type=click.Path(writable=True),
help="Path to write YAML merge config to",
)
@click.option(
"--dry-run", is_flag=True, help="Generate a config but do not run the merge"
)
@click.option("--fp16/--no-fp16", is_flag=True, help="Use FP16 precision")
@click.option(
"--full-random/--no-full-random",
is_flag=True,
help="Randomize layer index as well as source model",
)
@add_merge_options
def main(
out_path: Annotated[
str, typer.Argument(help="Output path for merged model", metavar="PATH")
],
model: Annotated[
List[str], typer.Option(help="Add a model to the merge", metavar="MODEL")
],
weight: Annotated[
List[float],
typer.Option(
help="Weighting for a model",
default_factory=list,
show_default=False,
),
],
print_yaml: Annotated[
bool, typer.Option(help="Print YAML merge config for resulting model")
] = False,
write_yaml: Annotated[
Optional[str], typer.Option(help="Path to write YAML merge config to")
] = None,
dry_run: Annotated[
bool, typer.Option(help="Generate a config but do not run the merge")
] = False,
fp16: bool = False,
lora_merge_cache: Annotated[
Optional[str],
typer.Option(help="Path to store merged LORA models", metavar="PATH"),
] = None,
transformers_cache: Annotated[
Optional[str],
typer.Option(
help="Override storage path for downloaded models", metavar="PATH"
),
] = None,
copy_tokenizer: Annotated[
bool, typer.Option(help="Copy a tokenizer to the output")
] = True,
full_random: Annotated[
bool, typer.Option(help="Randomize layer index as well as source model")
] = False,
out_path: str,
model: List[str],
weight: List[float],
print_yaml: bool,
write_yaml: bool,
dry_run: bool,
fp16: bool,
full_random: bool,
merge_options: MergeOptions,
):
models = [ModelReference.parse(m) for m in model]

Expand Down Expand Up @@ -135,17 +136,9 @@ def main(
run_merge(
merge_config,
out_path,
MergeOptions(
lora_merge_cache=lora_merge_cache,
transformers_cache=transformers_cache,
copy_tokenizer=copy_tokenizer,
),
options=merge_options,
)


def _main():
typer.run(main)


if __name__ == "__main__":
_main()
main()
Loading