Fix ordering of ppo epoch iteration #522

Merged (6 commits), Jul 31, 2023

Changes from 2 commits
33 changes: 21 additions & 12 deletions trlx/trainer/accelerate_base_trainer.py
@@ -542,21 +542,24 @@ def learn(self): # noqa: C901
 
         # For each epoch
         for _ in range(self.config.train.epochs):
-            # For each batch
-            for mbs in MiniBatchIterator(self.train_dataloader, self.mb_size, self.num_mb):
-                # For each update per batch
-                for _ in range(self.n_updates_per_batch):
-                    # Note that whereas standard policy gradient methods perform one
-                    # gradient update per batch, PPO for example commonly performs
-                    # multiple gradient updates on the same batch of data.
-                    # https://arxiv.org/pdf/1707.06347.pdf
-                    forward_time = 0
-                    backward_time = 0
+            # For each ppo epoch
+            for _ in range(self.n_inner_epochs):
+                # Note that whereas standard policy gradient methods perform one
+                # gradient update per batch, PPO for example commonly performs
+                # multiple epochs of gradient updates on the same batch of data.
+                # https://arxiv.org/pdf/1707.06347.pdf
+
+                # We create a new dataloader (so new data ordering and shuffle) each inner epoch
+                train_dataloader = self.create_train_dataloader()
+                # For each batch
+                for minibatch in MiniBatchIterator(train_dataloader, self.mb_size, self.num_mb):
+                    forward_time = 0.0
+                    backward_time = 0.0
                     stats_accum = []
-                    for mb in mbs:
+                    for microbatch in minibatch:
                         with self._accumulate():
                             forward_time -= time()
-                            loss, stats = self.loss(mb)
+                            loss, stats = self.loss(microbatch)
                             forward_time += time()
                             backward_time -= time()
                             self.accelerator.backward(loss)
@@ -633,6 +636,12 @@ def learn(self): # noqa: C901
             self.post_epoch_callback()
         tbar.close()
 
+    def create_train_dataloader(self, shuffle=True, accelerate_prepare=True):
+        dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=shuffle)
+        if accelerate_prepare:
+            dataloader = self.accelerator.prepare_dataloader(dataloader)
Collaborator:
I guess the accelerate_prepare argument is redundant here, but maybe the whole function implementation is as well, since it's overridden individually for each trainer.

Contributor (author):
We could make this an abstract function or a not-implemented function if it's OK for AccelerateRLTrainer to be an abstract class or to require subclassing, but I didn't know whether that was the case, so I put a standard implementation here instead.

Collaborator:
Yes, let's make it an abstract method just like the ones below it, with the only difference between implementations being whether or not we shard the dataloader, correct?

Contributor (author):
The difference is whether the dataloader is prepared by accelerate and whether you pass shuffle=True: PPO doesn't prepare and passes shuffle=True, while ILQL and SFT do prepare and don't pass shuffle=True. I've made the change now.
+        return dataloader
+
     @abstractmethod
     def get_arch(self, config: TRLConfig):
         """Returns a specific wrapper of the decoder architecture"""
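To make the new ordering in learn() concrete, here is a minimal, runnable sketch of the iteration structure after this change. It is not the trainer itself: make_dataloader and train_on are hypothetical stand-ins for create_train_dataloader() and the loss/backward step, and the sizes are arbitrary.

import random

def make_dataloader(dataset, batch_size, shuffle=True):
    # Yield batches, reshuffling the data each time this is called.
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start : start + batch_size]]

def train_on(batch):
    print("gradient update on", batch)

dataset = list(range(8))
epochs, n_inner_epochs, batch_size = 1, 2, 4

for _ in range(epochs):                    # outer training epochs
    for _ in range(n_inner_epochs):        # PPO-style inner epochs over the same data
        # A fresh dataloader per inner epoch gives a new ordering/shuffle,
        # mirroring train_dataloader = self.create_train_dataloader() above.
        for minibatch in make_dataloader(dataset, batch_size):
            train_on(minibatch)            # one update per minibatch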
10 changes: 6 additions & 4 deletions trlx/trainer/accelerate_ilql_trainer.py
@@ -158,18 +158,20 @@ def loss(self, batch: Union[ILQLBatch, ILQLSeq2SeqBatch]):
 
         return self.ilql.loss((logits, (qs, target_qs, vs)), batch)
 
+    def create_train_dataloader(self):
+        return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size))
+
     def prepare_learning(self):
-        train_dataloader = self.store.create_loader(self.config.train.batch_size)
+        self.train_dataloader = self.create_train_dataloader()
         eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
 
         (
             self.model,
             self.opt,
-            self.train_dataloader,
             self.eval_dataloader,
-        ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader)
+        ) = self.accelerator.prepare(self.model, self.opt, eval_dataloader)
 
-        self.n_updates_per_batch = 1
+        self.n_inner_epochs = 1
         self.total_steps = self.config.train.epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
 
9 changes: 6 additions & 3 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -220,16 +220,19 @@ def post_epoch_callback(self):
     def post_backward_callback(self):
         self.kl_ctl.update(self.mean_kl, n_steps=self.config.train.batch_size)
 
+    def create_train_dataloader(self):
+        return self.store.create_loader(self.config.train.batch_size)
+
     def prepare_learning(self):
         eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size)
         self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)
 
         self.make_experience(self.config.method.num_rollouts)
 
-        self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False)
+        self.train_dataloader = self.create_train_dataloader()
 
-        self.n_updates_per_batch = self.config.method.ppo_epochs
-        self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader)
+        self.n_inner_epochs = self.config.method.ppo_epochs
+        self.total_steps = self.config.train.epochs * self.n_inner_epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
 
     def add_prompt_pipeline(self, pipeline: PromptPipeline):
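For concreteness, a small worked example of the renamed step count: total_steps now multiplies the outer epochs by n_inner_epochs (formerly n_updates_per_batch) and the number of batches, then is capped by config.train.total_steps. All numbers below are hypothetical.

epochs = 100              # config.train.epochs
ppo_epochs = 4            # config.method.ppo_epochs -> self.n_inner_epochs
num_batches = 24          # len(self.train_dataloader)
total_steps_cap = 10_000  # config.train.total_steps

total_steps = epochs * ppo_epochs * num_batches    # 100 * 4 * 24 = 9600
total_steps = min(total_steps, total_steps_cap)    # 9600, still under the cap
print(total_steps)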
10 changes: 6 additions & 4 deletions trlx/trainer/accelerate_sft_trainer.py
@@ -72,18 +72,20 @@ def loss(self, batch):
 
         return loss, stats
 
+    def create_train_dataloader(self):
+        return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size))
+
     def prepare_learning(self):
-        train_dataloader = self.store.create_loader(self.config.train.batch_size)
+        self.train_dataloader = self.create_train_dataloader()
         eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
 
         (
             self.model,
             self.opt,
-            self.train_dataloader,
             self.eval_dataloader,
-        ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader)
+        ) = self.accelerator.prepare(self.model, self.opt, eval_dataloader)
 
-        self.n_updates_per_batch = 1
+        self.n_inner_epochs = 1
         self.total_steps = self.config.train.epochs * len(self.train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
 
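Following the review thread above, here is a rough sketch of the direction the reviewers discuss: create_train_dataloader declared as an abstract method on the base trainer, with each trainer deciding shuffling and whether to wrap the loader with accelerator.prepare. All names below (StoreStub, the *TrainerSketch classes) are simplified stand-ins, not the trlx classes, and the accelerate-prepare step is only indicated in a comment.

import random
from abc import ABC, abstractmethod
from typing import List


class StoreStub:
    """Minimal stand-in for the rollout/sample store."""

    def __init__(self, data: List[int]):
        self.data = data

    def create_loader(self, batch_size: int, shuffle: bool = False) -> List[List[int]]:
        data = list(self.data)
        if shuffle:
            random.shuffle(data)
        return [data[i : i + batch_size] for i in range(0, len(data), batch_size)]


class BaseTrainerSketch(ABC):
    def __init__(self, store: StoreStub, batch_size: int):
        self.store = store
        self.batch_size = batch_size

    @abstractmethod
    def create_train_dataloader(self):
        """Each subclass decides shuffling and whether to wrap with accelerator.prepare."""


class PPOTrainerSketch(BaseTrainerSketch):
    def create_train_dataloader(self):
        # PPO: not prepared by accelerate, reshuffled on every call
        return self.store.create_loader(self.batch_size, shuffle=True)


class SFTTrainerSketch(BaseTrainerSketch):
    def create_train_dataloader(self):
        # ILQL/SFT: would additionally be wrapped in accelerator.prepare(...); omitted here
        return self.store.create_loader(self.batch_size)


print(PPOTrainerSketch(StoreStub(list(range(6))), batch_size=2).create_train_dataloader())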