Specify chat template for output model (#367)
Adds a `chat_template` field to merge configs, which can be either a
Jinja template string or one of `chatml`, `llama3`, `alpaca`, `mistral`.
Also supports `auto`, which will try to select the most common template
among the input models.
cg123 authored Jul 16, 2024
1 parent aa0399f commit 5fa7782
Showing 9 changed files with 196 additions and 16 deletions.
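
To illustrate the new field, here is a minimal merge config sketch (the model names and weights below are placeholders for illustration, not part of this commit):

merge_method: linear
models:
  - model: example-org/model-a    # placeholder model
    parameters:
      weight: 0.5
  - model: example-org/model-b    # placeholder model
    parameters:
      weight: 0.5
dtype: bfloat16
chat_template: chatml             # or auto, llama3, alpaca, mistral, or a literal Jinja template string

With `chat_template: auto`, the template most common among the input models' tokenizers is applied to the merged tokenizer; a literal Jinja string is saved verbatim.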
Empty file.
29 changes: 29 additions & 0 deletions mergekit/_data/chat_templates/alpaca.jinja
@@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}
2 changes: 2 additions & 0 deletions mergekit/_data/chat_templates/chatml.jinja
@@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}
{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
7 changes: 7 additions & 0 deletions mergekit/_data/chat_templates/llama3.jinja
@@ -0,0 +1,7 @@
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}
{{ content }}
{% endfor %}
{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}
24 changes: 24 additions & 0 deletions mergekit/_data/chat_templates/mistral.jinja
@@ -0,0 +1,24 @@
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}

{{- bos_token }}
{%- for message in loop_messages %}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif %}
{%- if message['role'] == 'user' %}
{%- if loop.first and system_message is defined %}
{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}
{%- else %}
{{- ' [INST] ' + message['content'] + ' [/INST]' }}
{%- endif %}
{%- elif message['role'] == 'assistant' %}
{{- ' ' + message['content'] + eos_token}}
{%- else %}
{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}
{%- endif %}
{%- endfor %}
1 change: 1 addition & 0 deletions mergekit/config.py
@@ -93,6 +93,7 @@ class MergeConfiguration(BaseModel):
         Literal["union"], Literal["base"], ModelReference, None
     ] = None
     tokenizer: Optional[TokenizerConfig] = None
+    chat_template: Optional[str] = None
     out_dtype: Optional[str] = None
 
     def referenced_models(self) -> List[ModelReference]:
88 changes: 74 additions & 14 deletions mergekit/merge.py
@@ -13,14 +13,18 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
+import importlib
+import importlib.resources
 import logging
 import os
 import shutil
+from collections import Counter
 from typing import Optional
 
 import tqdm
 import transformers
 
+from mergekit._data import chat_templates
 from mergekit.architecture import ArchitectureInfo, get_architecture_info
 from mergekit.card import generate_card
 from mergekit.config import MergeConfiguration
@@ -116,32 +120,87 @@ def run_merge(
         ) as fp:
             fp.write(config_source)
 
-    if tokenizer is None and options.copy_tokenizer:
-        try:
-            _copy_tokenizer(
-                merge_config, out_path, trust_remote_code=options.trust_remote_code
-            )
-        except Exception as e:
-            logging.error(
-                "Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.",
-                exc_info=e,
+    if tokenizer is None:
+        if options.copy_tokenizer:
+            try:
+                _copy_tokenizer(
+                    merge_config, out_path, trust_remote_code=options.trust_remote_code
+                )
+            except Exception as e:
+                logging.error(
+                    "Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.",
+                    exc_info=e,
+                )
+        elif merge_config.chat_template:
+            logging.warning(
+                "Chat template specified but no tokenizer found. Chat template will not be saved."
             )
 
     if tokenizer:
         logging.info("Saving tokenizer")
+        _set_chat_template(tokenizer, merge_config)
         tokenizer.save_pretrained(out_path, safe_serialization=True)
 
 
+def _set_chat_template(
+    tokenizer: transformers.PreTrainedTokenizerBase,
+    merge_config: MergeConfiguration,
+    trust_remote_code: bool = False,
+):
+    chat_template = merge_config.chat_template
+    if not chat_template:
+        return
+
+    if chat_template == "auto":
+        # see if there is a plurality chat template among the input models
+        model_templates = []
+        for model in merge_config.referenced_models():
+            try:
+                tok = transformers.AutoTokenizer.from_pretrained(
+                    model.model.path,
+                    revision=model.model.revision,
+                    trust_remote_code=trust_remote_code,
+                )
+                template = tok.chat_template
+                if isinstance(template, dict):
+                    template = template.get("default", None)
+                if template:
+                    model_templates.append(template.strip())
+            except Exception as e:
+                logging.warning(f"Unable to load tokenizer for {model}", exc_info=e)
+
+        if not model_templates:
+            return
+
+        chat_template = Counter(model_templates).most_common(1)[0][0]
+        logging.info(f"Auto-selected chat template: {chat_template}")
+
+    elif importlib.resources.is_resource(chat_templates, chat_template + ".jinja"):
+        with importlib.resources.open_text(
+            chat_templates, chat_template + ".jinja"
+        ) as fp:
+            chat_template = fp.read()
+
+    elif len(chat_template) < 20 or "{" not in chat_template:
+        raise RuntimeError(f"Invalid chat template: {chat_template}")
+
+    tokenizer.chat_template = chat_template
+
+
 def _copy_tokenizer(
     merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
 ):
     donor_model = merge_config.base_model or (merge_config.referenced_models()[0])
 
-    if os.path.exists(
-        os.path.join(donor_model.model.path, "tokenizer_config.json")
-    ) and (
-        os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json"))
-        or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model"))
+    if (
+        (not merge_config.chat_template)
+        and os.path.exists(
+            os.path.join(donor_model.model.path, "tokenizer_config.json")
+        )
+        and (
+            os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json"))
+            or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model"))
+        )
     ):
         logging.info(f"Copying tokenizer from {donor_model}")
 
@@ -166,6 +225,7 @@ def _copy_tokenizer(
         revision=donor_model.model.revision,
         trust_remote_code=trust_remote_code,
     )
+    _set_chat_template(tokenizer, merge_config)
     tokenizer.save_pretrained(out_path, safe_serialization=True)
 
9 changes: 7 additions & 2 deletions pyproject.toml
@@ -57,9 +57,14 @@ packages = [
     "mergekit.tokenizer",
     "mergekit._data",
     "mergekit._data.architectures",
+    "mergekit._data.chat_templates",
 ]
 include-package-data = true
-package-data = { "mergekit._data.architectures" = ["*.json"] }
+package-data = { "mergekit._data.architectures" = [
+    "*.json",
+], "mergekit._data.chat_templates" = [
+    "*.jinja",
+] }
 
 [tool.isort]
 profile = "black"
@@ -74,6 +79,6 @@ minversion = "6.0"
 filterwarnings = [
     "ignore::pydantic.PydanticDeprecatedSince20:huggingface_hub.*:",
     "ignore::FutureWarning:huggingface_hub.*:",
-    "ignore:(read_text|open_text|contents) is deprecated:DeprecationWarning", # yes i know, but files() doesn't exist in 3.8
+    "ignore:(read_text|open_text|contents|is_resource) is deprecated:DeprecationWarning", # yes i know, but files() doesn't exist in 3.8
 ]
 testpaths = ["tests"]
52 changes: 52 additions & 0 deletions tests/test_chat_template.py
@@ -0,0 +1,52 @@
from typing import Optional

from common import run_and_check_merge
from test_basic_merges import model_b
from test_tokenizer import model_base
from transformers import AutoTokenizer

from mergekit.config import InputModelDefinition, MergeConfiguration


def check_chat_template(model_path: str, needle: Optional[str] = None):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if needle is None:
        assert not tokenizer.chat_template, "Expected no chat template"
        return
    assert (
        tokenizer.chat_template and needle in tokenizer.chat_template
    ), f"Expected chat template to contain {needle}"


class TestChatTemplate:
    def test_template_chatml(self, model_base, model_b):
        config = MergeConfiguration(
            merge_method="linear",
            models=[
                InputModelDefinition(model=model_base, parameters={"weight": 0.5}),
                InputModelDefinition(model=model_b, parameters={"weight": 0.5}),
            ],
            base_model=model_base,
            dtype="bfloat16",
            chat_template="chatml",
        )
        run_and_check_merge(
            config,
            validate=lambda p: check_chat_template(p, "<|im_start|>"),
        )

    def test_template_literal_jinja(self, model_base, model_b):
        config = MergeConfiguration(
            merge_method="linear",
            models=[
                InputModelDefinition(model=model_base, parameters={"weight": 0.5}),
                InputModelDefinition(model=model_b, parameters={"weight": 0.5}),
            ],
            base_model=model_base,
            dtype="bfloat16",
            chat_template="{{messages[0]['content']}}",
        )
        run_and_check_merge(
            config,
            validate=lambda p: check_chat_template(p, "{{messages[0]['content']}}"),
        )
