From e36fe9d107af4f54a4dfaa28082e6faeef5734af Mon Sep 17 00:00:00 2001
From: Max <56548574+maxreciprocate@users.noreply.github.com>
Date: Mon, 24 Jul 2023 14:27:49 +0300
Subject: [PATCH] fix(modeling_ppo): load reference head under zero3 (#489)

* fix(modeling_ppo): copy reference head from gathered parameters

* style: satisfy black

* fix(accelerate_ppo_trainer): pin `synced_gpus` under zero3

* fix(ppo_trainer): zero stage check on the newest accelerate version

---------

Co-authored-by: reciprocated <56548574+reciprocated@users.noreply.github.com>
---
 trlx/models/modeling_ppo.py             | 18 +++++++++++++-----
 trlx/trainer/accelerate_base_trainer.py |  2 ++
 trlx/trainer/accelerate_ppo_trainer.py  |  1 +
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 82d3ec637..743a07cb2 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -5,9 +5,9 @@
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

+import deepspeed
 import numpy as np
 import torch
-import torch.nn as nn
 import transformers
 from torchtyping import TensorType
 from transformers.modeling_outputs import ModelOutput
@@ -443,10 +443,18 @@ def __init__(
         super().__init__(base_model.config)

         # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model
-        decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model))
-        self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:])
-        self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model))
-        self.lm_head = deepcopy(hf_get_lm_head(base_model))
+
+        decoder_blocks = hf_get_decoder_blocks(base_model)[-num_layers_unfrozen:]
+        final_norm = hf_get_decoder_final_norm(base_model)
+        lm_head = hf_get_lm_head(base_model)
+
+        with deepspeed.zero.GatheredParameters(
+            list(decoder_blocks.parameters()) + list(final_norm.parameters()) + list(lm_head.parameters()),
+            modifier_rank=None,
+        ):
+            self.decoder_blocks = deepcopy(decoder_blocks)
+            self.final_norm = deepcopy(final_norm)
+            self.lm_head = deepcopy(lm_head)

         self.hidden_size = hf_get_hidden_size(self.config)
         self.model_parallel = False
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index d9b7ae8fc..b9faed065 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -568,7 +568,9 @@ def learn(self):  # noqa: C901
                             loss, stats = self.loss(mb)
                             forward_time += time()
                             backward_time -= time()
+                            self.model.train()
                             self.accelerator.backward(loss)
+                            self.model.eval()
                             backward_time += time()
                             stats_accum.append(stats)

diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 34ceb5e16..6a0b4ba71 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -90,6 +90,7 @@ def __init__(self, config: TRLConfig, **kwargs):
             use_cache=True,
             eos_token_id=self.tokenizer.eos_token_id,
             pad_token_id=self.tokenizer.pad_token_id,
+            synced_gpus=os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3",
         )
         self.generate_kwargs = {**generate_kwargs, **config.method.gen_kwargs}