Skip to content

Commit

Permalink
Cleanup and add flag for model card
Browse files: browse the repository at this point in the history
  • Loading branch information
cg123 committed Jan 2, 2024
1 parent c204fbd commit ae58a72
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
19 changes: 8 additions & 11 deletions mergekit/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,26 @@ def generate_card(config: MergeConfiguration, name: Optional[str] = None) -> str
if not name:
name = "Untitled Model (1)"

hf_bases = list(extract_hf_paths(config.referenced_models()))
tags = ["mergekit", "merge"]

actual_base = ModelReference.parse(config.base_model) if config.base_model else None
if config.merge_method == "slerp":
# curse my past self
actual_base = None

base_text = ""
if actual_base:
models = set(config.referenced_models()).difference({actual_base})
base_list = [actual_base] + list(models)
else:
base_list = config.referenced_models()

hf_bases = list(extract_hf_paths(base_list))
tags = ["mergekit", "merge"]
base_text = f" using {modelref_md(actual_base)} as a base"

model_bullets = []
for model in base_list:
for model in config.referenced_models():
if model == actual_base:
# actual_base is mentioned in base_text - don't include in list
continue

model_bullets.append("* " + modelref_md(model))

base_text = ""
if actual_base:
base_text = f" using {modelref_md(actual_base)} as a base"
return CARD_TEMPLATE.format(
metadata=yaml.dump({"base_model": hf_bases, "tags": tags}),
model_list="\n".join(model_bullets),
Expand Down
4 changes: 4 additions & 0 deletions mergekit/scripts/run_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def main(
lazy_unpickle: Annotated[
bool, typer.Option(help="Experimental lazy unpickler for lower memory usage")
] = False,
write_model_card: Annotated[
bool, typer.Option(help="Output README.md for huggingface hub")
] = True,
):
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

Expand All @@ -96,6 +99,7 @@ def main(
trust_remote_code=trust_remote_code,
clone_tensors=clone_tensors,
lazy_unpickle=lazy_unpickle,
write_model_card=write_model_card,
),
)

Expand Down

0 comments on commit ae58a72

Please sign in to comment.