Fix PPO log_ratio bug #509

Merged (4 commits) on Jun 23, 2023
15 changes: 10 additions & 5 deletions trlx/trainer/accelerate_ppo_trainer.py
```diff
@@ -175,12 +175,14 @@ def loss(self, batch: PPORLBatch):
                 logprobs, values_pred, mask = (
                     logprobs[:, start:end],
                     values_pred[:, start:end],
-                    mask[:, start:end],
+                    mask[:, start + 1 : end + 1],
                 )
             else:
                 tokens = torch.cat((query_tensors, response_tensors), dim=1)
                 attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
-                outputs = self.model(tokens, attention_mask, return_dict=True)
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                outputs = self.model(tokens, attention_mask, return_dict=True, position_ids=position_ids)
                 logits = outputs.logits
                 values_pred = outputs.value
                 values_pred = values_pred[:, :-1]
```
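The position-id computation introduced above can be illustrated in isolation. This is a minimal sketch with made-up mask values, not trlx code: with left padding, the default positions 0, 1, 2, ... would give a padded sequence's first real token a position that depends on the amount of padding, whereas deriving positions from the attention mask makes every sequence's first real token sit at position 0, matching how the tokens were positioned during generation.

```python
import torch

# Two sequences: the first is left-padded with two pad tokens,
# the second has no padding.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum counts real tokens seen so far; subtracting 1 makes the
# first real token position 0. Pad positions get an arbitrary filler (1),
# which is harmless because they are masked out of attention anyway.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```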
```diff
@@ -191,7 +193,7 @@ def loss(self, batch: PPORLBatch):
                 logprobs, values_pred, mask = (
                     logprobs[:, start:end],
                     values_pred[:, start:end],
-                    attention_mask[:, start:end],
+                    attention_mask[:, start + 1 : end + 1],
                 )

         loss, stats = self.config.method.loss(
```
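The one-position mask shift in both hunks above follows from how causal-LM log-probs are indexed. A minimal sketch with assumed toy shapes (not trlx's actual helper): `logits[:, t]` predicts token `t + 1`, so log-probs gathered over `logits[:, :-1]` align with `tokens[:, 1:]`. A mask built over the original token positions must therefore be shifted by one (`start + 1 : end + 1`) to stay aligned with log-probs sliced at `[start:end]`.

```python
import torch

tokens = torch.tensor([[0, 5, 6, 7, 8]])   # leading 0 is a pad token
logits = torch.randn(1, 5, 10)             # (batch, seq, vocab)

# Log-prob of each *next* token: drop the last logit position and
# gather at the shifted token ids.
logprobs = torch.log_softmax(logits[:, :-1], dim=-1)              # (1, 4, 10)
logprobs = torch.gather(logprobs, 2, tokens[:, 1:, None]).squeeze(-1)  # (1, 4)

mask = tokens.not_equal(0).long()          # mask over token positions, (1, 5)
start, end = 1, 4

aligned_logprobs = logprobs[:, start:end]        # log-probs of tokens 2..4
aligned_mask = mask[:, start + 1 : end + 1]      # mask of those same tokens
assert aligned_logprobs.shape == aligned_mask.shape
```

With the unshifted slice `mask[:, start:end]`, position `t` of the mask would refer to token `t` while position `t` of the log-probs refers to token `t + 1`, which is exactly the off-by-one this PR fixes.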
```diff
@@ -398,22 +400,25 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
         else:
             all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
             attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
             with torch.no_grad():
                 logits, *_, values = self.model(
-                    all_tokens,
-                    attention_mask=attention_mask,
+                    all_tokens, attention_mask=attention_mask, position_ids=position_ids
                 )
                 # TODO(dahoas): When hydra model works need to also support generation on hydra head
                 if hasattr(self.model, "frozen_head"):
                     ref_logits = self.model.forward_hydra(
                         all_tokens,
                         attention_mask=attention_mask,
+                        position_ids=position_ids,
                         return_dict=True,
                     ).logits
                 else:
                     ref_logits = self.ref_model(
                         all_tokens,
                         attention_mask=attention_mask,
+                        position_ids=position_ids,
```
Collaborator: Duplicate this change into the `self.model.forward_hydra` call as well; otherwise the `log_ratio` computed inside `make_experience` isn't equal to zero initially.
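The reviewer's point can be illustrated with a hypothetical toy example (made-up values, not trlx code): at the start of PPO training the policy and the reference model have identical weights, so given identical inputs their log-probs must match and `log_ratio` should be exactly zero. If only one of the two forward passes receives the corrected `position_ids`, the log-probs diverge and the initial KL penalty is spuriously non-zero.

```python
import torch

# Stand-in log-probs: with shared initial weights and *identical* inputs
# (same tokens, same attention mask, same position ids), the policy and
# reference forward passes produce the same values.
logprobs = torch.tensor([-1.2, -0.7, -2.3])
ref_logprobs = logprobs.clone()

log_ratio = logprobs - ref_logprobs
assert torch.all(log_ratio == 0)   # zero KL at initialization, as expected

# If the reference pass used default (padding-corrupted) position ids, its
# log-probs would differ, and log_ratio would be non-zero from step one.
```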

Contributor (author): Should it also be duplicated to the `if self.config.model.model_arch_type == "seq2seq"` branches?

Collaborator: Yes, please!

Contributor (author): Actually, I realize the forward methods for seq2seq models lack the `position_ids` argument, and AFAIK T5 uses relative positional biases rather than absolute ones, so this shouldn't be a problem for that model at least. I'm not sure whether there are other seq2seq models with absolute position embeddings that trlx supports.

Collaborator: You're correct, sorry for the confusion; there aren't any currently except T5.

```diff
                         return_dict=True,
                     ).logits
             ref_logits = ref_logits.to(device)
```