diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2f3ab8022..e564e85a7 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -175,12 +175,14 @@ def loss(self, batch: PPORLBatch): logprobs, values_pred, mask = ( logprobs[:, start:end], values_pred[:, start:end], - mask[:, start:end], + mask[:, start + 1 : end + 1], ) else: tokens = torch.cat((query_tensors, response_tensors), dim=1) attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) - outputs = self.model(tokens, attention_mask, return_dict=True) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + outputs = self.model(tokens, attention_mask, return_dict=True, position_ids=position_ids) logits = outputs.logits values_pred = outputs.value values_pred = values_pred[:, :-1] @@ -191,7 +193,7 @@ def loss(self, batch: PPORLBatch): logprobs, values_pred, mask = ( logprobs[:, start:end], values_pred[:, start:end], - attention_mask[:, start:end], + attention_mask[:, start + 1 : end + 1], ) loss, stats = self.config.method.loss( @@ -398,22 +400,25 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) with torch.no_grad(): logits, *_, values = self.model( - all_tokens, - attention_mask=attention_mask, + all_tokens, attention_mask=attention_mask, position_ids=position_ids ) # TODO(dahoas): When hydra model works need to also support generation on hydra head if hasattr(self.model, "frozen_head"): ref_logits = self.model.forward_hydra( all_tokens, attention_mask=attention_mask, + position_ids=position_ids, return_dict=True, ).logits else: ref_logits = self.ref_model( all_tokens, attention_mask=attention_mask, + position_ids=position_ids, return_dict=True, ).logits ref_logits = ref_logits.to(device)