Skip to content

Commit

Permalink
Add JAISLMHeadModel (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 authored Jan 13, 2024
1 parent deda4aa commit b1d84bf
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,31 @@ def num_layers(self, config: PretrainedConfig) -> int:
num_layers_key="n_layer",
)

# Static tensor-name description registered under the architecture name
# "JAISLMHeadModel".
JAIS_INFO = StaticTensorNames(
    name="JAISLMHeadModel",
    pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
    post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
    embed_weight_names=["transformer.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    # Every per-layer submodule listed here contributes a ``weight`` tensor
    # immediately followed by a ``bias`` tensor, so the suffixes are expanded
    # from the submodule names instead of being spelled out one by one.
    layer_weight_suffixes=[
        f"{submodule}.{param}"
        for submodule in (
            "attn.c_attn",
            "attn.c_proj",
            "ln_1",
            "ln_2",
            "mlp.c_fc",
            "mlp.c_fc2",
            "mlp.c_proj",
        )
        for param in ("weight", "bias")
    ],
    num_layers_key="n_layer",
)

GPT2_SEQCLASS_INFO = StaticTensorNames(
name="GPT2ForSequenceClassification",
pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
Expand Down Expand Up @@ -367,6 +392,7 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
GPT2_SEQCLASS_INFO,
CHATGLM_INFO,
STABLELM_INFO,
JAIS_INFO,
]
for arch in supported:
if arch.name == arch_name:
Expand Down

0 comments on commit b1d84bf

Please sign in to comment.