Commit 3c93187

Add support for tie_word_embeddings when loading weights + support for SmolLM (#1508)

TianyiQ committed Sep 25, 2024
1 parent fb2d068 commit 3c93187
Showing 3 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -263,6 +263,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - BaiChuan2
 - MiniCPM / MiniCPM 3
 - XVERSE / XVERSE MoE
+- SmolLM


**Embedding Models**
8 changes: 8 additions & 0 deletions python/sglang/srt/models/llama.py
@@ -403,6 +403,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

+        if (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+        ):
+            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
+            param = self.lm_head.weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, self.model.embed_tokens.weight)
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))


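The llama.py hunk copies the input embedding matrix into `lm_head` whenever the model config sets `tie_word_embeddings`, so checkpoints that omit `lm_head.weight` still load completely instead of leaving the output projection uninitialized. A minimal standalone sketch of the same idea, using a hypothetical toy module (`TinyLM` and its `load_weights` helper are illustrative, not SGLang's actual classes):

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Toy embedding/LM-head pair; stands in for a real decoder-only model."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)


def load_weights(model: TinyLM, weights: dict, tie_word_embeddings: bool) -> None:
    """Mimic the patched loader: consume the checkpoint, then, if tying is
    enabled, fill lm_head from embed_tokens rather than requiring it in the
    checkpoint."""
    params = dict(model.named_parameters())
    for name, tensor in weights.items():
        params[name].data.copy_(tensor)
    if tie_word_embeddings:
        # lm_head.weight was absent from the checkpoint; copy the input
        # embedding into it, matching the commit's behavior.
        model.lm_head.weight.data.copy_(model.embed_tokens.weight.data)


model = TinyLM()
ckpt = {"embed_tokens.weight": torch.randn(16, 8)}  # note: no lm_head.weight
load_weights(model, ckpt, tie_word_embeddings=True)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
```

This works because `nn.Embedding` stores its weight as `(vocab_size, hidden)` and `nn.Linear(hidden, vocab_size)` stores its weight as `(out_features, in_features) = (vocab_size, hidden)`, so the two matrices have identical shapes and can be copied directly.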
1 change: 1 addition & 0 deletions test/srt/models/test_generation_models.py
@@ -51,6 +51,7 @@ class ModelCase:
 # All other models
 ALL_OTHER_MODELS = [
     ModelCase("Qwen/Qwen2-1.5B"),
+    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
 ]

TORCH_DTYPES = [torch.float16]

0 comments on commit 3c93187