Specify chat template for output model #367
Merged
6 commits by cg123:
d01feed Add option to specify chat template for output model
53168fd Add 'auto' option
f553ed7 Update pyproject.toml
87bc45d Merge remote-tracking branch 'origin/main' into chat_template
b06edc9 Update pyproject.toml
ed1fe27 Remove debug spam, use logging
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} | ||
|
||
{% for message in messages %} | ||
{% if message['role'] == 'user' %} | ||
### Instruction: | ||
{{ message['content']|trim -}} | ||
{% if not loop.last %} | ||
|
||
|
||
{% endif %} | ||
{% elif message['role'] == 'assistant' %} | ||
### Response: | ||
{{ message['content']|trim -}} | ||
{% if not loop.last %} | ||
|
||
|
||
{% endif %} | ||
{% elif message['role'] == 'user_context' %} | ||
### Input: | ||
{{ message['content']|trim -}} | ||
{% if not loop.last %} | ||
|
||
|
||
{% endif %} | ||
{% endif %} | ||
{% endfor %} | ||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} | ||
### Response: | ||
{% endif %} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %} | ||
{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %} |
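For reference, the output of the ChatML template above can be approximated in plain Python. This is only a sketch, not mergekit code: `render_chatml` is a hypothetical helper, and it assumes the newline between the two template lines is trimmed when the template is rendered.

```python
from typing import Dict, List


def render_chatml(
    messages: List[Dict[str, str]], add_generation_prompt: bool = False
) -> str:
    """Approximate the ChatML template: one <|im_start|>...<|im_end|> block per message."""
    parts = []
    for message in messages:
        # Mirrors: '<|im_start|>' + role + '\n' + content + '<|im_end|>' + '\n'
        parts.append(
            "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
        )
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


if __name__ == "__main__":
    chat = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello"},
    ]
    print(render_chatml(chat, add_generation_prompt=True))
```

Unlike the Alpaca and Mistral templates in this PR, this one applies no role filtering: any role string is passed through as-is.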
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{% set loop_messages = messages %} | ||
{% for message in loop_messages %} | ||
{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %} | ||
{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %} | ||
{{ content }} | ||
{% endfor %} | ||
{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
{%- if messages[0]['role'] == 'system' %} | ||
{%- set system_message = messages[0]['content'] %} | ||
{%- set loop_messages = messages[1:] %} | ||
{%- else %} | ||
{%- set loop_messages = messages %} | ||
{%- endif %} | ||
|
||
{{- bos_token }} | ||
{%- for message in loop_messages %} | ||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} | ||
{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }} | ||
{%- endif %} | ||
{%- if message['role'] == 'user' %} | ||
{%- if loop.first and system_message is defined %} | ||
{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }} | ||
{%- else %} | ||
{{- ' [INST] ' + message['content'] + ' [/INST]' }} | ||
{%- endif %} | ||
{%- elif message['role'] == 'assistant' %} | ||
{{- ' ' + message['content'] + eos_token}} | ||
{%- else %} | ||
{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }} | ||
{%- endif %} | ||
{%- endfor %} |
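The role checks in the template above can be restated in plain Python. A minimal sketch under stated assumptions: `check_alternation` is a hypothetical name, it raises `ValueError` where the template calls `raise_exception`, and it tolerates an empty message list, which the template does not handle.

```python
from typing import Dict, List


def check_alternation(messages: List[Dict[str, str]]) -> None:
    """Raise ValueError unless roles alternate user/assistant after an optional system message."""
    # Skip a leading system message, as the template does with messages[1:].
    if messages and messages[0]["role"] == "system":
        loop_messages = messages[1:]
    else:
        loop_messages = messages

    for index, message in enumerate(loop_messages):
        # Even positions must be 'user'; odd positions must not be 'user'.
        if (message["role"] == "user") != (index % 2 == 0):
            raise ValueError(
                "After the optional system message, conversation roles "
                "must alternate user/assistant/user/assistant/..."
            )
        # Odd positions pass the check above for any non-user role, so reject others here.
        if message["role"] not in ("user", "assistant"):
            raise ValueError("Only user and assistant roles are supported")


if __name__ == "__main__":
    check_alternation(
        [
            {"role": "system", "content": "Be brief."},
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ]
    )
    print("conversation is well-formed")
```

The comparison `(role == 'user') != (index % 2 == 0)` is a compact exclusive-or: it fires whenever a user message sits at an odd position or a non-user message sits at an even one.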
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,14 +13,18 @@ | |
# You should have received a copy of the GNU Lesser General Public License | ||
# along with this program. If not, see http://www.gnu.org/licenses/. | ||
|
||
import importlib | ||
import importlib.resources | ||
import logging | ||
import os | ||
import shutil | ||
from collections import Counter | ||
from typing import Optional | ||
|
||
import tqdm | ||
import transformers | ||
|
||
from mergekit._data import chat_templates | ||
from mergekit.architecture import ArchitectureInfo, get_architecture_info | ||
from mergekit.card import generate_card | ||
from mergekit.config import MergeConfiguration | ||
|
@@ -116,32 +120,87 @@ def run_merge( | |
) as fp: | ||
fp.write(config_source) | ||
|
||
if tokenizer is None and options.copy_tokenizer: | ||
try: | ||
_copy_tokenizer( | ||
merge_config, out_path, trust_remote_code=options.trust_remote_code | ||
) | ||
except Exception as e: | ||
logging.error( | ||
"Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.", | ||
exc_info=e, | ||
if tokenizer is None: | ||
if options.copy_tokenizer: | ||
try: | ||
_copy_tokenizer( | ||
merge_config, out_path, trust_remote_code=options.trust_remote_code | ||
) | ||
except Exception as e: | ||
logging.error( | ||
"Failed to copy tokenizer. The merge was still successful, just copy it from somewhere else.", | ||
exc_info=e, | ||
) | ||
elif merge_config.chat_template: | ||
logging.warning( | ||
"Chat template specified but no tokenizer found. Chat template will not be saved." | ||
) | ||
|
||
if tokenizer: | ||
logging.info("Saving tokenizer") | ||
_set_chat_template(tokenizer, merge_config) | ||
tokenizer.save_pretrained(out_path, safe_serialization=True) | ||
|
||
|
||
def _set_chat_template( | ||
tokenizer: transformers.PreTrainedTokenizerBase, | ||
merge_config: MergeConfiguration, | ||
trust_remote_code: bool = False, | ||
): | ||
chat_template = merge_config.chat_template | ||
if not chat_template: | ||
return | ||
|
||
if chat_template == "auto": | ||
# see if there is a plurality chat template among the input models | ||
model_templates = [] | ||
for model in merge_config.referenced_models(): | ||
try: | ||
tok = transformers.AutoTokenizer.from_pretrained( | ||
model.model.path, | ||
revision=model.model.revision, | ||
trust_remote_code=trust_remote_code, | ||
) | ||
template = tok.chat_template | ||
if isinstance(template, dict): | ||
template = template.get("default", None) | ||
if template: | ||
model_templates.append(template.strip()) | ||
except Exception as e: | ||
logging.warning(f"Unable to load tokenizer for {model}", exc_info=e) | ||
|
||
if not model_templates: | ||
return | ||
|
||
chat_template = Counter(model_templates).most_common(1)[0][0] | ||
Review comment: This logic can be made more robust later but this should be good for now. |
||
logging.info(f"Auto-selected chat template: {chat_template}") | ||
|
||
elif importlib.resources.is_resource(chat_templates, chat_template + ".jinja"): | ||
with importlib.resources.open_text( | ||
chat_templates, chat_template + ".jinja" | ||
) as fp: | ||
chat_template = fp.read() | ||
|
||
elif len(chat_template) < 20 or "{" not in chat_template: | ||
raise RuntimeError(f"Invalid chat template: {chat_template}") | ||
|
||
tokenizer.chat_template = chat_template | ||
|
||
|
||
def _copy_tokenizer( | ||
merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False | ||
): | ||
donor_model = merge_config.base_model or (merge_config.referenced_models()[0]) | ||
|
||
if os.path.exists( | ||
os.path.join(donor_model.model.path, "tokenizer_config.json") | ||
) and ( | ||
os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json")) | ||
or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model")) | ||
if ( | ||
(not merge_config.chat_template) | ||
and os.path.exists( | ||
os.path.join(donor_model.model.path, "tokenizer_config.json") | ||
) | ||
and ( | ||
os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json")) | ||
or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model")) | ||
) | ||
): | ||
logging.info(f"Copying tokenizer from {donor_model}") | ||
|
||
|
@@ -166,6 +225,7 @@ def _copy_tokenizer( | |
revision=donor_model.model.revision, | ||
trust_remote_code=trust_remote_code, | ||
) | ||
_set_chat_template(tokenizer, merge_config) | ||
tokenizer.save_pretrained(out_path, safe_serialization=True) | ||
|
||
|
||
|
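The `auto` branch of `_set_chat_template` above picks the plurality template among the input models. The selection step can be sketched in isolation, without loading any tokenizers. `pick_plurality_template` is a hypothetical helper name; the input list stands in for the `chat_template` values read from each model's tokenizer.

```python
from collections import Counter
from typing import List, Optional


def pick_plurality_template(templates: List[Optional[str]]) -> Optional[str]:
    """Return the most common non-empty template (whitespace-stripped), or None."""
    # Drop missing/empty templates, then strip so trivial whitespace
    # differences do not split the vote.
    cleaned = [t.strip() for t in templates if t]
    if not cleaned:
        return None
    # most_common(1) returns [(template, count)]; ties go to the first one seen.
    return Counter(cleaned).most_common(1)[0][0]


if __name__ == "__main__":
    candidates = ["{{ a }}\n", "{{ b }}", "{{ a }}", None]
    print(pick_plurality_template(candidates))
```

Because ties resolve to the first template encountered, the result in a tie depends on the order in which the models are listed, which may be one of the robustness concerns the review comment alludes to.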
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import Optional | ||
|
||
from common import run_and_check_merge | ||
from test_basic_merges import model_b | ||
from test_tokenizer import model_base | ||
from transformers import AutoTokenizer | ||
|
||
from mergekit.config import InputModelDefinition, MergeConfiguration | ||
|
||
|
||
def check_chat_template(model_path: str, needle: Optional[str] = None): | ||
tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
if needle is None: | ||
assert not tokenizer.chat_template, "Expected no chat template" | ||
return | ||
assert ( | ||
tokenizer.chat_template and needle in tokenizer.chat_template | ||
), f"Expected chat template to contain {needle}" | ||
|
||
|
||
class TestChatTemplate: | ||
def test_template_chatml(self, model_base, model_b): | ||
config = MergeConfiguration( | ||
merge_method="linear", | ||
models=[ | ||
InputModelDefinition(model=model_base, parameters={"weight": 0.5}), | ||
InputModelDefinition(model=model_b, parameters={"weight": 0.5}), | ||
], | ||
base_model=model_base, | ||
dtype="bfloat16", | ||
chat_template="chatml", | ||
) | ||
run_and_check_merge( | ||
config, | ||
validate=lambda p: check_chat_template(p, "<|im_start|>"), | ||
) | ||
|
||
def test_template_literal_jinja(self, model_base, model_b): | ||
config = MergeConfiguration( | ||
merge_method="linear", | ||
models=[ | ||
InputModelDefinition(model=model_base, parameters={"weight": 0.5}), | ||
InputModelDefinition(model=model_b, parameters={"weight": 0.5}), | ||
], | ||
base_model=model_base, | ||
dtype="bfloat16", | ||
chat_template="{{messages[0]['content']}}", | ||
) | ||
run_and_check_merge( | ||
config, | ||
validate=lambda p: check_chat_template(p, "{{messages[0]['content']}}"), | ||
) |
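The two tests above exercise a built-in template name (`chatml`) and a literal Jinja string. How a `chat_template` value is dispatched between those cases can be sketched as follows. This is an illustration rather than the PR's code: `resolve_template` and `BUILTIN_TEMPLATES` are hypothetical, a plain dict stands in for the packaged `.jinja` resources, and the `auto` case is omitted.

```python
from typing import Dict, Optional

# Hypothetical stand-in for the packaged .jinja files; the value is a placeholder,
# not the real ChatML template text.
BUILTIN_TEMPLATES: Dict[str, str] = {
    "chatml": "{# placeholder for the ChatML template text #}",
}


def resolve_template(
    spec: Optional[str], builtins: Dict[str, str] = BUILTIN_TEMPLATES
) -> Optional[str]:
    """Map a chat_template setting to template text, or None if unset."""
    if not spec:
        return None
    if spec in builtins:
        # A known name resolves to the bundled template.
        return builtins[spec]
    # Same heuristic as the diff: a literal template must be at least
    # 20 characters long and contain a '{'.
    if len(spec) < 20 or "{" not in spec:
        raise RuntimeError(f"Invalid chat template: {spec}")
    return spec


if __name__ == "__main__":
    print(resolve_template("chatml"))
    print(resolve_template("{{messages[0]['content']}}"))
```

The length-and-brace check is a cheap guard against typos in a template name being saved as a literal template; it does not validate that the string is well-formed Jinja.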
🔥