Fix: Include num_value_layers_unfrozen arg in Seq2Seq model init
Dahoas committed Jul 26, 2023
1 parent eb95a42 commit d48c1a4
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions trlx/models/modeling_ppo.py
@@ -1191,14 +1191,18 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
 
     _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM
     _supported_modules = ["v_head"]
-    _supported_args = ["peft_config"]
+    _supported_args = ["peft_config", "num_value_layers_unfrozen"]
 
     def __init__(
         self,
         base_model: transformers.PreTrainedModel,
         peft_config=None,
+        num_value_layers_unfrozen=0,
     ):
         super().__init__(base_model, peft_config=peft_config)
+        # TODO: Support Seq2Seq value branching
+        if num_value_layers_unfrozen > 0:
+            raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
         self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
 
     def forward(
@@ -1299,16 +1303,17 @@ def post_init(self, state_dict):
 
 class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead):
     _supported_modules = ["v_head", "frozen_head"]
-    _supported_args = ["num_layers_unfrozen", "peft_config"]
+    _supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]
 
     def __init__(
         self,
         base_model: transformers.PreTrainedModel,
         *,
         num_layers_unfrozen: int = -1,
         peft_config=None,
+        num_value_layers_unfrozen: int = 0,
     ):
-        super().__init__(base_model, peft_config=peft_config)
+        super().__init__(base_model, peft_config=peft_config, num_value_layers_unfrozen=num_value_layers_unfrozen)
         self.num_layers_unfrozen = num_layers_unfrozen
 
         if self.num_layers_unfrozen > 0 and not self.peft_type:
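For context (not part of the commit), a minimal sketch of how the new keyword reaches the Seq2Seq wrappers. trlx normally constructs these wrappers through their from_pretrained classmethod; the direct constructor calls and the "t5-small" checkpoint below are illustrative assumptions only.

    import transformers

    from trlx.models.modeling_ppo import AutoModelForSeq2SeqLMWithHydraValueHead

    # Hypothetical base checkpoint, chosen only for illustration.
    base_model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # Default num_value_layers_unfrozen=0 preserves the previous behaviour.
    model = AutoModelForSeq2SeqLMWithHydraValueHead(base_model)

    # Requesting value-branch layers now reaches the Seq2Seq __init__, which
    # rejects the argument until value branching is implemented for this
    # architecture (per the TODO added in this commit).
    try:
        AutoModelForSeq2SeqLMWithHydraValueHead(
            transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small"),
            num_value_layers_unfrozen=1,
        )
    except NotImplementedError as err:
        print(err)  # Value branches unsupported for Seq2Seq architecture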
