Fix ordering of ppo epoch iteration #522
Conversation
Suppose you have 256 rollouts and you batch them into batches B1, B2, B3, B4 (each of size 64). The order of gradient updates (assuming 3 ppo_epochs) is: `trlx: B1 B1 B1 B2 B2 B2 B3 B3 B3 B4 B4 B4`. However, what we should actually be doing (and what alpaca-farm, other RLHF implementations, and standard implementations of PPO do) is: `improved: B1 B2 B3 B4 B1 B2 B3 B4 B1 B2 B3 B4`. It would be even better to produce new random batches at each ppo_epoch, i.e. `optimal: B1 B2 B3 B4 B1' B2' B3' B4' B1* B2* B3* B4*`, but that would require more refactoring. This change basically just reorders the learning loop so that the code uses the `improved` ordering above. It also renames n_updates_per_batch to n_inner_epochs, as that's a more accurate description (especially now), adjusts forward_time and backward_time so they don't type-error, and renames mbs and mb to minibatch and microbatch (as that's what they are).
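A minimal sketch (not trlx's actual code) contrasting the two loop structures, with a hypothetical `train_step` standing in for one gradient update on a minibatch:

```python
def old_order(batches, n_inner_epochs, train_step):
    # trlx before this PR: exhaust all inner epochs on one batch before
    # moving on -> B1 B1 B1 B2 B2 B2 ...
    for batch in batches:
        for _ in range(n_inner_epochs):
            train_step(batch)


def improved_order(batches, n_inner_epochs, train_step):
    # This PR: sweep over all batches once per inner epoch
    # -> B1 B2 B3 B4 B1 B2 B3 B4 ...
    for _ in range(n_inner_epochs):
        for batch in batches:
            train_step(batch)
```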
Splendid! Thank you; this implementation detail had been left out. Here are comparisons between runs using this PR and the main branch: https://wandb.ai/sorry/trlx-references/reports/main-v-main--Vmlldzo0ODc2MTAx (forgive the ambiguity of naming both branches main; the respective top runs in wandb are from this PR). The difference isn't too notable on most of these, though there is a difference for ILQL runs; but since self.n_inner_epochs=1 for ILQL, this change shouldn't have an effect there, I suppose.
Yeah, you're right, that's what I meant by "on most of these", though I rather meant that there are no obvious deteriorations; it was not meant in a demeaning way towards this change, since it's obviously correct and in line with other implementations. (Top runs are from this PR, terribly sorry for the confusion.)
This way we get better shuffling. Note that we now pass shuffle=True (implicitly) in the PPO trainer, whereas before we had shuffle=False. Shuffling is better here, as it makes the gradient estimates over minibatches less correlated.
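For illustration, a standalone PyTorch sketch (not trlx code) of why shuffle=True helps: a DataLoader with shuffle=True draws a fresh permutation every time it is iterated, so each pass over the rollouts sees a new batching:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 256 toy "rollouts", batched into 4 minibatches of 64.
rollouts = TensorDataset(torch.arange(256))
loader = DataLoader(rollouts, batch_size=64, shuffle=True)

for epoch in range(3):  # each pass yields a new random batching (B1', B2', ...)
    for minibatch in loader:
        pass  # run a PPO update on `minibatch` here
```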
@maxreciprocate Do we want to merge this now?
Although there isn't a difference for SFT runs, there is a difference for ILQL runs; however, I can't figure out why exactly, since the data ordering stays unchanged, as it should:
https://wandb.ai/sorry/trlx-references/reports/main-v-main--Vmlldzo0OTQ2MzEw
```python
def create_train_dataloader(self, shuffle=True, accelerate_prepare=True):
    # Build a dataloader over the rollout store; shuffling makes consecutive
    # minibatch gradients less correlated.
    dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=shuffle)
    if accelerate_prepare:
        # Let accelerate shard the dataloader across processes.
        dataloader = self.accelerator.prepare_dataloader(dataloader)
    return dataloader
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the accelerate_prepare argument is redundant here, and maybe the whole function implementation is too, since it's overridden individually for each trainer.
We could make this an abstract function or a not-implemented function if it's OK for AccelerateRLTrainer to be an abstract class or require subclassing, but I didn't know whether that was the case, so I put a standard implementation here instead.
Yes, let's make it an abstract method just like the ones below it, with the only difference between implementations being whether or not we shard the dataloader, correct?
The difference is whether the dataloader is prepared by accelerate and whether you pass shuffle=True: PPO doesn't prepare and passes shuffle=True, while ILQL and SFT do prepare and don't pass shuffle=True. I've made the change now; a sketch of the resulting structure is below.
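A minimal structural sketch of that split, assuming the trainers keep their `store`, `config`, and `accelerator` attributes from the real codebase (the class bodies here are illustrative, not trlx's actual code):

```python
from abc import ABC, abstractmethod


class AccelerateRLTrainer(ABC):
    @abstractmethod
    def create_train_dataloader(self):
        """Each trainer builds its own train dataloader."""


class AcceleratePPOTrainer(AccelerateRLTrainer):
    def create_train_dataloader(self):
        # PPO: shuffle the rollouts and skip accelerate's preparation.
        return self.store.create_loader(self.config.train.batch_size, shuffle=True)


class AccelerateILQLTrainer(AccelerateRLTrainer):
    def create_train_dataloader(self):
        # ILQL/SFT: keep the stored order and let accelerate shard the loader.
        dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False)
        return self.accelerator.prepare_dataloader(dataloader)
```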
It could be seeded differently, as we pass the dataloaders to accelerate prepare in a different order than we did before? Or there's uncontrolled non-determinism that the seed isn't setting?
Probably, that's most likely the reason, since I've checked the order in which tensors come in for ILQL and it hasn't changed.
Excellent fix, greatly appreciated @RobertKirk!