diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index df9a2ef57..65d2c6ecf 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -270,8 +270,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): _auto_model_parent_class = transformers.AutoModelForCausalLM _supported_modules = ["v_head"] - _supported_args = ["peft_config"] - _supported_args = ["num_value_layers_unfrozen"] + _supported_args = ["peft_config", "num_value_layers_unfrozen"] def __init__( self, @@ -329,6 +328,7 @@ def forward( else: outputs = self.base_model(**forward_kwargs) + # TODO: Apply PEFT to value branch if self.num_value_layers_unfrozen > 0: output_shape = outputs.hidden_states[-1].size() forward_kwargs.pop("input_ids", None)