
Fix: Remove _supported_args overwrite in AutoModelForCausalLMWithValueHead
Dahoas committed Jul 26, 2023
1 parent: 5683470 · commit: eb95a42
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trlx/models/modeling_ppo.py
@@ -270,8 +270,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
 
     _auto_model_parent_class = transformers.AutoModelForCausalLM
     _supported_modules = ["v_head"]
-    _supported_args = ["peft_config"]
-    _supported_args = ["num_value_layers_unfrozen"]
+    _supported_args = ["peft_config", "num_value_layers_unfrozen"]
 
     def __init__(
         self,
@@ -329,6 +328,7 @@ def forward(
         else:
             outputs = self.base_model(**forward_kwargs)
 
+        # TODO: Apply PEFT to value branch
         if self.num_value_layers_unfrozen > 0:
            output_shape = outputs.hidden_states[-1].size()
            forward_kwargs.pop("input_ids", None)
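
Why the fix was needed: in Python, class-level assignments to the same attribute execute in order when the class body runs, so the second _supported_args line silently replaced the first and "peft_config" stopped being a supported argument. A minimal standalone sketch of the overwrite (the Wrapper class name is illustrative, not trlx's actual class):

class Wrapper:
    # Both assignments run at class-creation time; the second wins,
    # so "peft_config" is dropped from the supported arguments.
    _supported_args = ["peft_config"]
    _supported_args = ["num_value_layers_unfrozen"]

print(Wrapper._supported_args)  # ['num_value_layers_unfrozen']

Merging both values into a single list, as the commit does, keeps the earlier "peft_config" support intact while adding "num_value_layers_unfrozen".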
