diff --git a/examples/ppo_dense_sentiments.py b/examples/ppo_dense_sentiments.py
new file mode 100644
index 000000000..6f6601123
--- /dev/null
+++ b/examples/ppo_dense_sentiments.py
@@ -0,0 +1,75 @@
+# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
+# with a sentiment reward function
+import json
+import os
+import sys
+from typing import List
+
+import torch
+from datasets import load_dataset
+from transformers import pipeline
+
+import trlx
+from trlx.data.default_configs import TRLConfig, default_ppo_config
+
+
+def get_positive_score(scores):
+    "Extract value associated with a positive sentiment from pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+
+def get_negative_score(scores):
+    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "lvwerra/distilbert-imdb",
+        top_k=2,
+        truncation=True,
+        batch_size=256,
+        device=device,
+    )
+
+    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], tokenizer, **kwargs) -> List[List[float]]:
+        # Reward reviews that start out negative and then turn positive
+        # Reward functions should never receive padded text except for a single EOS at the end
+        # The reward function should return per-token rewards for just the response
+        first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples]
+        negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
+        second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples]
+        positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
+        text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
+        tok_scores = []
+        for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
+            toks = tokenizer(response).input_ids
+            tok_score = [0] * len(toks)
+            tok_score[len(tok_score) // 2] = text_score[0]
+            tok_score[-1] = text_score[1]
+            tok_scores.append(tok_score)
+        return tok_scores
+
+    # Take a few words off of movie reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+
+    trlx.train(
+        reward_fn=dense_reward_fn,
+        prompts=prompts,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py
index c3203775c..7c897e2c5 100644
--- a/examples/randomwalks/ppo_randomwalks.py
+++ b/examples/randomwalks/ppo_randomwalks.py
@@ -59,12 +59,12 @@ def main(hparams={}):
     trlx.train(
         # An "optimality" reward function is used, with scores in [0,1]
         # depending on how close the path is to the shortest possible path.
-        reward_fn=lambda samples, prompts, outputs: metric_fn(samples)["optimality"],
+        reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
         # The prompts are simply the first nodes (represented as letters) to
         # start from.
         prompts=prompts,
         eval_prompts=prompts,
-        metric_fn=lambda samples, prompts, outputs: metric_fn(samples),
+        metric_fn=lambda samples, **kwargs: metric_fn(samples),
         config=config,
     )
 
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 5c82335c0..cf3b58c5e 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -427,10 +427,19 @@ def evaluate(self):  # noqa: C901
         # in online setting, compute the reward for validation
         if self.reward_fn:
             logger.info("Computing rewards")
-            rewards = torch.tensor(
-                self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata),
-                dtype=float,
+            rewards = self.reward_fn(
+                samples=str_samples,
+                prompts=str_prompts,
+                outputs=str_outputs,
+                tokenizer=self.tokenizer,
+                **metadata,
             )
+            if isinstance(rewards[0], torch.Tensor):
+                rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float)
+            elif isinstance(rewards[0], list):
+                rewards = torch.tensor([sum(reward) for reward in rewards], dtype=float)
+            else:
+                rewards = torch.tensor(rewards, dtype=float)
             mean_reward = rewards.mean().item()
             columns.append("reward")
             if not isinstance(rewards, list):
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index a3af9aa3f..34ceb5e16 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -4,9 +4,11 @@
 from time import time
 from typing import Callable, List
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import transformers
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
 
@@ -297,24 +299,39 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 )
 
                 rollout_score_time = time()
-                all_scores = torch.tensor(
-                    self.reward_fn(
-                        samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata
-                    ),
-                    dtype=torch.float,
-                    device=device,
+                # reward_fn should return a list of per-token rewards for each sample
+                # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (because of how the KL penalty is computed)
+                all_scores = self.reward_fn(
+                    samples=all_str_samples,
+                    prompts=all_str_prompts,
+                    outputs=all_str_outputs,
+                    tokenizer=self.tokenizer,
+                    **metadata,
                 )
+                all_scores = [
+                    torch.tensor(score, dtype=torch.float, device=device).view(
+                        -1,
+                    )
+                    for score in all_scores
+                ]
+                # Pad with -inf so that padded positions can be masked out later via scores_mask
+                all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf)
+                max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device)
+
                 stats["time/rollout_score"] = time() - rollout_score_time
 
-                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind())
+                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
             else:
                 all_scores = None
+                max_len = torch.tensor(0, dtype=torch.long, device=device)
 
             if torch.distributed.is_initialized():
-                scores = torch.empty(len(samples), device=device)
+                torch.distributed.broadcast(max_len, 0)
+                scores = torch.empty((len(samples), max_len), device=device)
                 torch.distributed.scatter(scores, all_scores)
             else:
                 scores = all_scores[0].clone().detach()
+            scores_mask = scores != -np.inf
 
             str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
 
@@ -342,8 +359,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
 
             # store statistics of the initial rollout as reference
             if self.ref_mean is None:
-                self.ref_mean, self.ref_std = scores.mean(), scores.std()
-            all_scores_mean, all_scores_std = self.running_moments.update(scores)
+                self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(
+                    dim=1
+                ).std()
+            all_scores_mean, all_scores_std = self.running_moments.update(torch.sum(scores * scores_mask, dim=1))
             stats["rollout_scores/mean"] = all_scores_mean.item()
             stats["rollout_scores/std"] = all_scores_std.item()
             stats["rollout_scores/running_mean"] = self.running_moments.mean.item()
@@ -415,6 +434,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
                 ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
             else:
+                # NOTE: logprobs[i] is the log-probability with which all_tokens[i+1] was sampled
                 logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
                 ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
 
@@ -439,7 +459,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             values = values.cpu()[:, :-1]
 
             # Get the logprobs and values, for tokens that are not padding,
-            # from the start of the prompt up to the <eos> token, while also including the latter
+            # from the end of the prompt up to the <eos> token, while also including the latter
             # (these are taken from the student model and not the reference model)
             ends = start + attention_mask[:, start:].sum(1) + 1
             all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
@@ -452,7 +472,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
 
             for sample_idx in range(n_samples):
                 rewards = kl_penalty[sample_idx]
-                rewards[-1] += scores[sample_idx].cpu()
+                # Then add in the reward scores
+                if scores.shape[1] == 1:
+                    # NOTE: the final scalar reward is given at the EOS token, following HHH practice
+                    rewards[-1] += scores[sample_idx][0].cpu()
+                else:
+                    score = scores[sample_idx]
+                    score_right_padding = torch.sum(scores_mask[sample_idx])
+                    score = score[:score_right_padding].cpu()
+                    p_score = torch.zeros_like(rewards)
+                    p_score[: score.shape[0]] += score
+                    rewards += p_score
 
                 ppo_rl_elements.append(
                     PPORLElement(
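
Not part of the patch: a minimal, self-contained sketch of the padding-and-masking scheme that make_experience uses above. The variable name per_token_rewards is illustrative, not part of the trlx API. Variable-length per-token reward lists are padded to a common length with -inf, and a boolean mask marks the real positions; zeroing the padded entries and summing then recovers each sample's total reward, mirroring how evaluate() collapses per-token rewards into a scalar.

# Illustrative sketch only; `per_token_rewards` is a made-up name, not a trlx API.
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

# Per-token rewards for two responses of different lengths,
# as a dense reward_fn might return them
per_token_rewards = [
    torch.tensor([0.0, 0.0, 0.5, 0.0, 0.9]),
    torch.tensor([0.0, 0.2, 0.7]),
]

# Pad to a common length with -inf so that padding cannot be confused with a real 0.0 reward
scores = pad_sequence(per_token_rewards, batch_first=True, padding_value=-np.inf)
scores_mask = scores != -np.inf  # True at real reward positions, False at padding

# Zero out the padded entries, then sum per sample to recover scalar rewards
per_sample_reward = scores.masked_fill(~scores_mask, 0.0).sum(dim=1)
print(per_sample_reward)  # tensor([1.4000, 0.9000])

The trainer multiplies scores by scores_mask instead; masked_fill is used here only so the snippet stands alone without the rest of the rollout code.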