Freeze "output" embedding when using tied embeddings. (#543)
* fix

* don't freeze if not frozen

---------

Co-authored-by: cat-state <cat@meow>
cat-state committed Aug 2, 2023
1 parent 1d69f90 commit 5735bf3
Showing 1 changed file with 7 additions and 1 deletion.
trlx/utils/modeling.py: 8 changes (7 additions & 1 deletion)
```diff
@@ -19,15 +19,21 @@ def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
     )


-def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):
+def freeze_bottom_causal_layers(model: transformers.PreTrainedModel, num_layers_unfrozen: int = 0):
     """Freezes the bottom transformer block layers of the specified model."""
     hidden_layers = hf_get_decoder_blocks(model)

     if num_layers_unfrozen == 0:
         hidden_layers_to_freeze = list(hidden_layers)
+        hidden_layers_to_freeze += [model.get_input_embeddings(), model.get_output_embeddings()]
     elif num_layers_unfrozen > 0:
         hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen]
+        hidden_layers_to_freeze += [model.get_input_embeddings()]
+        if model.config.tie_word_embeddings:
+            hidden_layers_to_freeze += [model.get_output_embeddings()]
     else:
         hidden_layers_to_freeze = []

     for layer in hidden_layers_to_freeze:
         layer.requires_grad_(False)
```
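For context, here is a minimal usage sketch (not part of the commit) of how the updated function behaves on a model with tied embeddings. The `gpt2` checkpoint name and the trainable-parameter check are illustrative assumptions; GPT-2 ties its token embedding and LM head, so the `tie_word_embeddings` branch above is exercised.

```python
# Hypothetical usage sketch, assuming a standard Hugging Face causal LM checkpoint.
from transformers import AutoModelForCausalLM

from trlx.utils.modeling import freeze_bottom_causal_layers

model = AutoModelForCausalLM.from_pretrained("gpt2")
assert model.config.tie_word_embeddings
# With tied embeddings, the input and output embedding modules share one weight tensor.
assert model.get_input_embeddings().weight is model.get_output_embeddings().weight

# Freeze everything except the top two transformer blocks.
freeze_bottom_causal_layers(model, num_layers_unfrozen=2)

assert not model.get_input_embeddings().weight.requires_grad
assert not model.get_output_embeddings().weight.requires_grad  # frozen via the tied branch

# Roughly the top two blocks, plus modules the function does not touch
# (e.g. final layer norm and positional embeddings), remain trainable.
trainable = sorted(name for name, p in model.named_parameters() if p.requires_grad)
print(trainable)
```

The `tie_word_embeddings` guard appears to follow the second message in the commit ("don't freeze if not frozen"): an untied LM head is left trainable, and the output embedding is only added to the freeze list when it shares weights with the already-frozen input embedding.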