diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index 047f1a9c..4eaab88e 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -206,6 +206,7 @@ def num_layers(self, config: PretrainedConfig) -> int:
     ],
 )
 
+
 CHATGLM_INFO = StaticTensorNames(
     name="ChatGLMModel",
     pre_weight_names=[
@@ -233,6 +234,30 @@ def num_layers(self, config: PretrainedConfig) -> int:
 )
 
 
+JAPANESE_STABLELM_INFO = StaticTensorNames(
+    name="JapaneseStableLMAlphaForCausalLM",
+    pre_weight_names=["transformer.embed_in.weight"],
+    post_weight_names=[
+        "transformer.final_layer_norm.bias",
+        "transformer.final_layer_norm.weight",
+        "embed_out.weight",
+    ],
+    embed_weight_names=["transformer.embed_in.weight", "embed_out.weight"],
+    layer_prefix_format="transformer.layers.{idx}",
+    layer_weight_suffixes=[
+        "attention.dense.weight",
+        "attention.query_key_value.weight",
+        # computed from config, should be safe to exclude
+        # "attention.rotary_emb.inv_freq",
+        # "attention.rotary_emb.scale",
+        "mlp.out_proj.weight",
+        "mlp.packed_input_proj.weight",
+        "post_attention_layernorm.bias",
+        "post_attention_layernorm.weight",
+    ],
+)
+
+
 class PhiTensorNames(ArchitectureInfo):
     architecture_name: str = "MixFormerSequentialForCausalLM"
 
@@ -331,6 +356,7 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
         CHATGLM_INFO,
         STABLELM_INFO,
         PHI2_INFO,
+        JAPANESE_STABLELM_INFO,
     ]
     for arch in supported:
         if arch.name == arch_name:
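
A minimal usage sketch, separate from the diff above: assuming a Hugging Face config whose "architectures" entry is "JapaneseStableLMAlphaForCausalLM", get_architecture_info should now resolve to JAPANESE_STABLELM_INFO, and per-layer tensor names can be derived from the static fields the diff defines. The model id below is an illustrative assumption, not something the diff specifies.

    # Sketch only: assumes a Hub model that advertises the
    # "JapaneseStableLMAlphaForCausalLM" architecture in its config.
    from transformers import AutoConfig

    from mergekit.architecture import get_architecture_info

    config = AutoConfig.from_pretrained(
        "stabilityai/japanese-stablelm-base-alpha-7b",  # assumed model id
        trust_remote_code=True,
    )
    info = get_architecture_info(config)  # -> JAPANESE_STABLELM_INFO

    # Build the expected name of one layer-0 tensor from fields in the diff:
    prefix = info.layer_prefix_format.format(idx=0)  # "transformer.layers.0"
    print(f"{prefix}.attention.query_key_value.weight")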