Fix ordering of ppo epoch iteration #522

Merged: 6 commits merged into CarperAI:main on Jul 31, 2023
Conversation

@RobertKirk (Contributor) commented on Jul 13, 2023

Suppose you have 256 rollouts, and you batch them into batches B1, B2, B3, B4 (each of size 64). The order of gradient updates (assuming 3 ppo_epochs) is:

`trlx: B1 B1 B1 B2 B2 B2 B3 B3 B3 B4 B4 B4`

However, what we should actually be doing (and what alpaca-farm, other RLHF implementations, and standard PPO implementations do) is:

`improved: B1 B2 B3 B4 B1 B2 B3 B4 B1 B2 B3 B4`

It would be even better to produce new random batches at each ppo_epoch, but that would require more refactoring:

`optimal: B1 B2 B3 B4 B1' B2' B3' B4' B1* B2* B3* B4*`

This change reorders the learning loop so the code uses the `improved` ordering above (see the sketch below). It also renames n_updates_per_batch to n_inner_epochs, as that's a more accurate description (especially now), adjusts forward_time and backward_time so they don't type-error, and renames mbs and mb to minibatch and microbatch (since that's what they are).
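A minimal Python sketch (not the actual trlx training loop) contrasting the `trlx` and `improved` orderings, using the same example of four minibatches and three PPO epochs:

```python
# Minimal illustration only, not trlx code: contrast the two update orderings.
batches = ["B1", "B2", "B3", "B4"]  # 256 rollouts split into 4 minibatches of 64
ppo_epochs = 3

# Old trlx ordering: exhaust every PPO epoch on one minibatch before moving on.
trlx_order = [b for b in batches for _ in range(ppo_epochs)]
assert trlx_order == ["B1", "B1", "B1", "B2", "B2", "B2",
                      "B3", "B3", "B3", "B4", "B4", "B4"]

# Improved ordering: sweep over all minibatches once per PPO epoch.
improved_order = [b for _ in range(ppo_epochs) for b in batches]
assert improved_order == ["B1", "B2", "B3", "B4"] * ppo_epochs
```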

@maxreciprocate (Collaborator) left a comment


Splendid! Thank you, this implementation detail had indeed been left out. Here are comparisons between runs using this PR and the main branch: https://wandb.ai/sorry/trlx-references/reports/main-v-main--Vmlldzo0ODc2MTAx (forgive the ambiguity of naming the branches main and main; the respective top runs in wandb are from this PR). The difference isn't too notable on most of these, but there is a difference for the ILQL runs; since self.n_inner_epochs=1 there, though, this change shouldn't have an effect on them, I suppose.

Review thread on trlx/trainer/accelerate_base_trainer.py (outdated, resolved)
@RobertKirk: this comment was marked as resolved.

@maxreciprocate (Collaborator)

Yeah, you're right, that's what I meant by "on most of these". What I rather meant is that there are no obvious deteriorations; it was not meant in a demeaning way towards this change, since it's obviously correct and in line with other implementations.

(Top ones are from this PR, terribly sorry for the confusion)

This way we get better shuffling. Note that we now pass shuffle=True (implicitly) in the PPO trainer, whereas before we had shuffle=False. Shuffling is better here, as it means the gradient estimation over minibatches is less correlated.
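For illustration, a generic PyTorch sketch (not the trlx trainer itself) of what shuffle=True buys: the DataLoader draws a fresh permutation of the rollouts on every pass, so successive minibatches are decorrelated from rollout generation order.

```python
# Generic PyTorch illustration, not trlx code: with shuffle=True the DataLoader
# re-permutes the dataset each time it is iterated, so minibatch gradient
# estimates are less tied to the order in which rollouts were generated.
import torch
from torch.utils.data import DataLoader, TensorDataset

rollouts = TensorDataset(torch.arange(256))      # stand-in for 256 rollouts
loader = DataLoader(rollouts, batch_size=64, shuffle=True)

for epoch in range(3):                           # e.g. 3 PPO epochs
    for (minibatch,) in loader:                  # minibatch order differs every epoch
        pass                                     # loss / gradient step would go here
```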
@Dahoas (Collaborator) commented on Jul 20, 2023

@maxreciprocate Do we want to merge this now?

@maxreciprocate (Collaborator) left a comment


While there isn't a difference for the SFT runs, there is a difference for the ILQL runs; however, I can't figure out why exactly, since the data ordering stays unchanged, as it should:

https://wandb.ai/sorry/trlx-references/reports/main-v-main--Vmlldzo0OTQ2MzEw

```python
def create_train_dataloader(self, shuffle=True, accelerate_prepare=True):
    dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=shuffle)
    if accelerate_prepare:
        dataloader = self.accelerator.prepare_dataloader(dataloader)
```
Collaborator:

I guess the accelerate_prepare argument is redundant here, but maybe the whole function implementation is as well, since it's overridden individually for each trainer.

@RobertKirk (Contributor, Author):

We could make this an abstract function or a not-implemented function if it's ok for AccelerateRLTrainer to be an abstract class or require subclassing, but I didn't know whether that was the case, so I put a standard implementation here instead.

Collaborator:

Yes, let's make it an abstract method, just like the ones below it, with the only difference between the implementations being whether or not we shard the dataloader, correct?

@RobertKirk (Contributor, Author):

The difference is whether the dataloader is prepared by accelerate and whether you pass shuffle=True: PPO doesn't prepare and passes shuffle=True, while ILQL and SFT do prepare and don't pass shuffle=True. I've made the change now.
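For clarity, a rough sketch of the shape this thread converges on; it is an illustration, not the exact trlx code, and it assumes the attributes used in the snippet quoted above (self.store, self.config, self.accelerator):

```python
# Illustrative sketch only (not the exact trlx implementation): an abstract
# create_train_dataloader on the base trainer, with per-trainer overrides that
# differ only in shuffling and whether accelerate prepares the loader.
from abc import ABC, abstractmethod

class AccelerateRLTrainer(ABC):
    @abstractmethod
    def create_train_dataloader(self):
        """Each trainer builds (and optionally accelerate-prepares) its own loader."""

class AcceleratePPOTrainer(AccelerateRLTrainer):
    def create_train_dataloader(self):
        # PPO: shuffle the rollouts, skip accelerate preparation
        return self.store.create_loader(self.config.train.batch_size, shuffle=True)

class AccelerateILQLTrainer(AccelerateRLTrainer):
    def create_train_dataloader(self):
        # ILQL/SFT: keep the original order, let accelerate prepare the loader
        dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False)
        return self.accelerator.prepare_dataloader(dataloader)  # as in the snippet above
```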

@RobertKirk (Contributor, Author)

It could be seeded differently, as we pass the dataloaders to accelerate's prepare in a different order than we did before? Or is there uncontrolled non-determinism that the seed isn't controlling?

@maxreciprocate (Collaborator)

Probably, that's the most likely reason, since I've checked the order in which tensors come in for ILQL and it hasn't changed.
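As an aside, if one wanted to rule out the uncontrolled non-determinism hypothesis, a generic way to do it (not something this PR does) is to pin all the seeds up front with accelerate's set_seed helper:

```python
# Generic reproducibility check, not part of this PR: seed Python's random,
# NumPy, and torch (including CUDA) in one call before training, so remaining
# run-to-run differences can be attributed to data/preparation ordering.
from accelerate.utils import set_seed

set_seed(42)
```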

@maxreciprocate (Collaborator) left a comment


Excellent fix, greatly appreciated @RobertKirk!

@maxreciprocate merged commit 6f7f59d into CarperAI:main on Jul 31, 2023
2 checks passed