Skip to content

Commit

Permalink
fix(modeling_ppo): load reference head under zero3 (#489)
Browse files (browse the repository at this point in the history)
* fix(modeling_ppo): copy reference head from gathered parameters

* style: satisfy black

* fix(accelerate_ppo_trainer): pin `synced_gpus` under zero3

* fix(ppo_trainer): zero stage check on the newest accelerate version

---------

Co-authored-by: reciprocated <[email protected]>
  • Loading branch information
maxreciprocate committed Jul 24, 2023
1 parent dbdefd8 commit e36fe9d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
18 changes: 13 additions & 5 deletions trlx/models/modeling_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import deepspeed
import numpy as np
import torch
import torch.nn as nn
import transformers
from torchtyping import TensorType
from transformers.modeling_outputs import ModelOutput
Expand Down Expand Up @@ -443,10 +443,18 @@ def __init__(
super().__init__(base_model.config)

# The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model
decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model))
self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:])
self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model))
self.lm_head = deepcopy(hf_get_lm_head(base_model))

decoder_blocks = hf_get_decoder_blocks(base_model)[-num_layers_unfrozen:]
final_norm = hf_get_decoder_final_norm(base_model)
lm_head = hf_get_lm_head(base_model)

with deepspeed.zero.GatheredParameters(
list(decoder_blocks.parameters()) + list(final_norm.parameters()) + list(lm_head.parameters()),
modifier_rank=None,
):
self.decoder_blocks = deepcopy(decoder_blocks)
self.final_norm = deepcopy(final_norm)
self.lm_head = deepcopy(lm_head)

self.hidden_size = hf_get_hidden_size(self.config)
self.model_parallel = False
Expand Down
2 changes: 2 additions & 0 deletions trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,9 @@ def learn(self): # noqa: C901
loss, stats = self.loss(mb)
forward_time += time()
backward_time -= time()
self.model.train()
self.accelerator.backward(loss)
self.model.eval()
backward_time += time()
stats_accum.append(stats)

Expand Down
1 change: 1 addition & 0 deletions trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(self, config: TRLConfig, **kwargs):
use_cache=True,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
synced_gpus=os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3",
)
self.generate_kwargs = {**generate_kwargs, **config.method.gen_kwargs}

Expand Down

0 comments on commit e36fe9d

Please sign in to comment.