Fix PPO log_ratio bug #509
Conversation
Note that this PR only sets `position_ids` and shifts the mask for non-seq2seq models. I have only tried this on a GPT-2 model, and I am not sure whether this bug also applies to seq2seq models.
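For context, this is the standard way to derive `position_ids` from an attention mask (a minimal sketch with made-up tensors, not the PR's code): with left padding, positions have to be counted from the first real token, otherwise the absolute positional embeddings get shifted by the padding length.

```python
import torch

# Minimal sketch, not the PR's code: derive position_ids from the attention
# mask so that left padding doesn't shift absolute positional embeddings.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.cumsum(dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)  # dummy position for padding
# position_ids is now:
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```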
Thanks Tobias! This is an extremely valuable find
```diff
@@ -414,6 +435,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa
 ref_logits = self.ref_model(
     all_tokens,
     attention_mask=attention_mask,
+    position_ids=position_ids,
```
Duplicate this change into the `self.model.forward_hydra` call as well, otherwise `log_ratio` computed inside `make_experience` isn't equal to zero initially.
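i.e. something along these lines (a sketch only; the exact signature of `forward_hydra` and the surrounding variable names are assumed here, not copied from the PR):

```diff
 logits = self.model.forward_hydra(
     all_tokens,
     attention_mask=attention_mask,
+    position_ids=position_ids,
 )
```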
Should it also be duplicated to the `if self.config.model.model_arch_type == "seq2seq"` branches?
Yes please!
Actually, I realize the `forward` methods for seq2seq models lack the `position_ids` argument, and at least T5 uses relative positional biases AFAIK, not absolute, in which case this should not be a problem for that model at least. I'm not sure whether there are other seq2seq models with absolute positional embeddings that TRLX supports?
You're correct, sorry for the confusion; there aren't any currently except T5.
https://wandb.ai/sorry/trlx-references/reports/fix-logratio-bug-v-main--Vmlldzo0NzE2NzYw
Thanks again, Tobias
LGTM!
Relevant issue: #508
Sets `position_ids` when computing logprobs, both in `make_experience` and in `loss`, to ensure the same absolute positional embeddings are used in the two methods. The mask applied to `logratio` is shifted by one, to correctly mask the last token in the batch.

Note: remember to remove the debug print statements before merge.
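For illustration, a self-contained sketch of the shifted masking (all names and shapes here are illustrative assumptions, not the PR's code): the logprob at position t scores the token at position t + 1, so the mask must be truncated by one column to line up with the next-token logprobs and exclude the final token, which has no continuation.

```python
import torch

batch, seq_len, vocab = 2, 5, 8
torch.manual_seed(0)

# Illustrative stand-ins, not the PR's exact identifiers.
logits = torch.randn(batch, seq_len, vocab)
ref_logits = torch.randn(batch, seq_len, vocab)
tokens = torch.randint(vocab, (batch, seq_len))
attention_mask = torch.tensor([[0, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Logprob of token t+1 under the distribution predicted at position t,
# so both logprob tensors have length seq_len - 1.
logprobs = torch.log_softmax(logits[:, :-1], dim=-1).gather(
    2, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
ref_logprobs = torch.log_softmax(ref_logits[:, :-1], dim=-1).gather(
    2, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

# The shift described above: truncate the mask by one so it aligns with the
# next-token logprobs and the last token in each sequence is masked out.
log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
```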