Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specify chat template for output model #367

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
29 changes: 29 additions & 0 deletions mergekit/_data/chat_templates/alpaca.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}
2 changes: 2 additions & 0 deletions mergekit/_data/chat_templates/chatml.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}
{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
7 changes: 7 additions & 0 deletions mergekit/_data/chat_templates/llama3.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}
{{ content }}
{% endfor %}
{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}
24 changes: 24 additions & 0 deletions mergekit/_data/chat_templates/mistral.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}

{{- bos_token }}
{%- for message in loop_messages %}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif %}
{%- if message['role'] == 'user' %}
{%- if loop.first and system_message is defined %}
{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}
{%- else %}
{{- ' [INST] ' + message['content'] + ' [/INST]' }}
{%- endif %}
{%- elif message['role'] == 'assistant' %}
{{- ' ' + message['content'] + eos_token}}
{%- else %}
{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}
{%- endif %}
{%- endfor %}
1 change: 1 addition & 0 deletions mergekit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class MergeConfiguration(BaseModel):
Literal["union"], Literal["base"], ModelReference, None
] = None
tokenizer: Optional[TokenizerConfig] = None
chat_template: Optional[str] = None
out_dtype: Optional[str] = None

def referenced_models(self) -> List[ModelReference]:
Expand Down
88 changes: 74 additions & 14 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import importlib
import importlib.resources
import logging
import os
import shutil
from collections import Counter
from typing import Optional

import tqdm
import transformers

from mergekit._data import chat_templates
from mergekit.architecture import ArchitectureInfo, get_architecture_info
from mergekit.card import generate_card
from mergekit.config import MergeConfiguration
Expand Down Expand Up @@ -116,32 +120,87 @@ def run_merge(
) as fp:
fp.write(config_source)

if tokenizer is None and options.copy_tokenizer:
try:
_copy_tokenizer(
merge_config, out_path, trust_remote_code=options.trust_remote_code
)
except Exception as e:
logging.error(
"Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.",
exc_info=e,
if tokenizer is None:
if options.copy_tokenizer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥

try:
_copy_tokenizer(
merge_config, out_path, trust_remote_code=options.trust_remote_code
)
except Exception as e:
logging.error(
"Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.",
exc_info=e,
)
elif merge_config.chat_template:
logging.warning(
"Chat template specified but no tokenizer found. Chat template will not be saved."
)

if tokenizer:
logging.info("Saving tokenizer")
_set_chat_template(tokenizer, merge_config)
tokenizer.save_pretrained(out_path, safe_serialization=True)


def _set_chat_template(
    tokenizer: transformers.PreTrainedTokenizerBase,
    merge_config: MergeConfiguration,
    trust_remote_code: bool = False,
):
    """Apply the chat template requested in *merge_config* to *tokenizer*.

    Three forms of ``merge_config.chat_template`` are supported:
      * ``"auto"`` — load the tokenizer of every referenced model and pick
        the most common chat template among them (plurality vote).
      * a bundled template name (e.g. ``"chatml"``) — resolved against the
        ``mergekit._data.chat_templates`` package resources.
      * a literal Jinja template string.

    No-op when no chat template is configured. Mutates *tokenizer* in place;
    the caller is responsible for saving it.

    Raises:
        RuntimeError: if the template is neither a known bundled name nor a
            plausible literal Jinja string (too short or containing no ``{``).
    """
    chat_template = merge_config.chat_template
    if not chat_template:
        return

    if chat_template == "auto":
        # See if there is a plurality chat template among the input models.
        model_templates = []
        for model in merge_config.referenced_models():
            try:
                tok = transformers.AutoTokenizer.from_pretrained(
                    model.model.path,
                    revision=model.model.revision,
                    trust_remote_code=trust_remote_code,
                )
                template = tok.chat_template
                if isinstance(template, dict):
                    # Some tokenizers ship multiple named templates.
                    template = template.get("default", None)
                if template:
                    model_templates.append(template.strip())
            except Exception as e:
                # Best-effort: a model whose tokenizer fails to load simply
                # doesn't get a vote.
                logging.warning(f"Unable to load tokenizer for {model}", exc_info=e)

        if not model_templates:
            return

        # Ties are broken arbitrarily by Counter insertion order; acceptable
        # for now — this logic can be made more robust later.
        chat_template = Counter(model_templates).most_common(1)[0][0]
        logging.info(f"Auto-selected chat template: {chat_template}")

    elif importlib.resources.is_resource(chat_templates, chat_template + ".jinja"):
        # Named template bundled with mergekit (alpaca, chatml, llama3, ...).
        with importlib.resources.open_text(
            chat_templates, chat_template + ".jinja"
        ) as fp:
            chat_template = fp.read()

    elif len(chat_template) < 20 or "{" not in chat_template:
        # Heuristic sanity check: a literal Jinja template should be
        # non-trivial and contain at least one brace.
        raise RuntimeError(f"Invalid chat template: {chat_template}")

    tokenizer.chat_template = chat_template


def _copy_tokenizer(
merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
):
donor_model = merge_config.base_model or (merge_config.referenced_models()[0])

if os.path.exists(
os.path.join(donor_model.model.path, "tokenizer_config.json")
) and (
os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json"))
or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model"))
if (
(not merge_config.chat_template)
and os.path.exists(
os.path.join(donor_model.model.path, "tokenizer_config.json")
)
and (
os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json"))
or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model"))
)
):
logging.info(f"Copying tokenizer from {donor_model}")

Expand All @@ -166,6 +225,7 @@ def _copy_tokenizer(
revision=donor_model.model.revision,
trust_remote_code=trust_remote_code,
)
_set_chat_template(tokenizer, merge_config)
tokenizer.save_pretrained(out_path, safe_serialization=True)


Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ packages = [
"mergekit.tokenizer",
"mergekit._data",
"mergekit._data.architectures",
"mergekit._data.chat_templates",
]
include-package-data = true
package-data = { "mergekit._data.architectures" = ["*.json"] }
package-data = { "mergekit._data.architectures" = [
"*.json",
], "mergekit._data.chat_templates" = [
"*.jinja",
] }

[tool.isort]
profile = "black"
Expand All @@ -74,6 +79,6 @@ minversion = "6.0"
filterwarnings = [
"ignore::pydantic.PydanticDeprecatedSince20:huggingface_hub.*:",
"ignore::FutureWarning:huggingface_hub.*:",
"ignore:(read_text|open_text|contents) is deprecated:DeprecationWarning", # yes i know, but files() doesn't exist in 3.8
"ignore:(read_text|open_text|contents|is_resource) is deprecated:DeprecationWarning", # yes i know, but files() doesn't exist in 3.8
]
testpaths = ["tests"]
52 changes: 52 additions & 0 deletions tests/test_chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Optional

from common import run_and_check_merge
from test_basic_merges import model_b
from test_tokenizer import model_base
from transformers import AutoTokenizer

from mergekit.config import InputModelDefinition, MergeConfiguration


def check_chat_template(model_path: str, needle: Optional[str] = None):
    """Assert that the tokenizer saved at *model_path* has the expected chat template.

    When *needle* is ``None``, the tokenizer must have no chat template at
    all; otherwise a template must be present and contain *needle*.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    template = tokenizer.chat_template
    if needle is None:
        assert not template, "Expected no chat template"
    else:
        assert (
            template and needle in template
        ), f"Expected chat template to contain {needle}"


class TestChatTemplate:
    """End-to-end tests for the ``chat_template`` merge configuration option."""

    @staticmethod
    def _linear_config(model_base, model_b, template: str) -> MergeConfiguration:
        # Shared fixture: a simple 50/50 linear merge carrying the template.
        return MergeConfiguration(
            merge_method="linear",
            models=[
                InputModelDefinition(model=model_base, parameters={"weight": 0.5}),
                InputModelDefinition(model=model_b, parameters={"weight": 0.5}),
            ],
            base_model=model_base,
            dtype="bfloat16",
            chat_template=template,
        )

    def test_template_chatml(self, model_base, model_b):
        # A bundled template name should be resolved and embedded verbatim.
        config = self._linear_config(model_base, model_b, "chatml")
        run_and_check_merge(
            config,
            validate=lambda p: check_chat_template(p, "<|im_start|>"),
        )

    def test_template_literal_jinja(self, model_base, model_b):
        # A literal Jinja string should be passed through unchanged.
        literal = "{{messages[0]['content']}}"
        config = self._linear_config(model_base, model_b, literal)
        run_and_check_merge(
            config,
            validate=lambda p: check_chat_template(p, literal),
        )
Loading