Dense reward carper (Fine grained feedback) (#514)
* Implementing support for dense rewards

* Fix distributed ref_mean, ref_var bug for dense rewards

* Fix black

* Remove annoying comments

* Fixing reward padding, simplifying running_moment updates

* Fixing style

* Fix missing dtype in trainer rewards tensor (#520)

* fix(ppo_randomwalks): `reward_fn` signature to accommodate tokenizer

* Rename example + fix nits (#527)

---------

Co-authored-by: Glavin Wiechert <[email protected]>
Co-authored-by: maxreciprocate <[email protected]>
3 people committed Jul 19, 2023
1 parent ba947e5 commit 0c94ee8
Showing 4 changed files with 131 additions and 17 deletions.
75 changes: 75 additions & 0 deletions examples/ppo_dense_sentiments.py
@@ -0,0 +1,75 @@
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def get_negative_score(scores):
    "Extract value associated with a negative sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]


def main(hparams={}):
    # Merge sweep config with default config if given
    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], tokenizer, **kwargs) -> List[List[float]]:
        # Reward positively for an initially negative, then positive review
        # Reward functions should never receive padded text except for a single EOS at the end
        # Reward function should return token rewards for just the response
        first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples]
        negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
        second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples]
        positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
        text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
        tok_scores = []
        for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
            toks = tokenizer(response).input_ids
            tok_score = [0] * len(toks)
            tok_score[len(tok_score) // 2] = text_score[0]
            tok_score[-1] = text_score[1]
            tok_scores.append(tok_score)
        return tok_scores

    # Take a few words off of movie reviews as prompts
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=dense_reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
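As a quick sanity check, here is a minimal standalone sketch (not part of the commit) of the per-token reward layout the example produces: each response gets a list as long as its token count, zero everywhere except the middle token (negative-half score) and the final token (positive-half score). The whitespace toy_tokenize helper and the fixed scores are illustrative stand-ins for the real tokenizer and sentiment pipeline.

from typing import List


def toy_tokenize(text: str) -> List[str]:
    # Stand-in for a real tokenizer: one "token" per whitespace-separated word
    return text.split()


def toy_dense_rewards(outputs: List[str], text_scores: List[List[float]]) -> List[List[float]]:
    # Mirrors the layout used by dense_reward_fn above
    tok_scores = []
    for response, (neg_score, pos_score) in zip(outputs, text_scores):
        tok_score = [0.0] * len(toy_tokenize(response))
        tok_score[len(tok_score) // 2] = neg_score
        tok_score[-1] = pos_score
        tok_scores.append(tok_score)
    return tok_scores


print(toy_dense_rewards(["terrible start but a wonderful ending overall"], [[0.9, 0.8]]))
# [[0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.8]]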
4 changes: 2 additions & 2 deletions examples/randomwalks/ppo_randomwalks.py
@@ -59,12 +59,12 @@ def main(hparams={}):
    trlx.train(
        # An "optimality" reward function is used, with scores in [0,1]
        # depending on how close the path is to the shortest possible path.
        reward_fn=lambda samples, prompts, outputs: metric_fn(samples)["optimality"],
        reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
        # The prompts are simply the first nodes (represented as letters) to
        # start from.
        prompts=prompts,
        eval_prompts=prompts,
        metric_fn=lambda samples, prompts, outputs: metric_fn(samples),
        metric_fn=lambda samples, **kwargs: metric_fn(samples),
        config=config,
    )

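The signature change above follows the convention that reward and metric callables absorb any extra keyword arguments the trainer passes (prompts, outputs, and now the tokenizer) via **kwargs. A minimal sketch of a compatible callable, with a made-up path-length heuristic standing in for the example's optimality metric:

from typing import List


def path_length_reward(samples: List[str], **kwargs) -> List[float]:
    # Extra keyword arguments (prompts, outputs, tokenizer, ...) are accepted and ignored,
    # so the trainer can pass whatever context it has without breaking the callable.
    return [1.0 / max(len(sample.split()), 1) for sample in samples]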
15 changes: 12 additions & 3 deletions trlx/trainer/accelerate_base_trainer.py
@@ -427,10 +427,19 @@ def evaluate(self): # noqa: C901
# in online setting, compute the reward for validation
if self.reward_fn:
    logger.info("Computing rewards")
    rewards = torch.tensor(
        self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata),
        dtype=float,
    rewards = self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
        model_tok=self.tokenizer,
        **metadata,
    )
    if isinstance(rewards[0], torch.Tensor):
        rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float)
    elif isinstance(rewards[0], list):
        rewards = torch.tensor([sum(reward) for reward in rewards], dtype=float)
    else:
        rewards = torch.tensor(rewards, dtype=float)
    mean_reward = rewards.mean().item()
    columns.append("reward")
    if not isinstance(rewards, list):
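The branching added above lets evaluate() log one scalar per sample whether the reward function returns per-token tensors, plain lists, or scalars. A standalone sketch of the same normalization (the helper name is ours, not the trainer's):

from typing import List, Union

import torch


def to_scalar_rewards(rewards: Union[List[torch.Tensor], List[List[float]], List[float]]) -> torch.Tensor:
    # Collapse dense (per-token) rewards to one scalar per sample for logging
    if isinstance(rewards[0], torch.Tensor):
        return torch.tensor([reward.sum().item() for reward in rewards], dtype=float)
    if isinstance(rewards[0], list):
        return torch.tensor([sum(reward) for reward in rewards], dtype=float)
    return torch.tensor(rewards, dtype=float)


print(to_scalar_rewards([[0.0, 0.9, 0.8], [0.5]]))  # tensor([1.7000, 0.5000], dtype=torch.float64)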
54 changes: 42 additions & 12 deletions trlx/trainer/accelerate_ppo_trainer.py
Expand Up @@ -4,9 +4,11 @@
from time import time
from typing import Callable, List

import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

@@ -297,24 +299,39 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
    )

    rollout_score_time = time()
    all_scores = torch.tensor(
        self.reward_fn(
            samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata
        ),
        dtype=torch.float,
        device=device,
    # reward_fn should return list of rewards at each token per sample
    # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed)
    all_scores = self.reward_fn(
        samples=all_str_samples,
        prompts=all_str_prompts,
        outputs=all_str_outputs,
        tokenizer=self.tokenizer,
        **metadata,
    )
    all_scores = [
        torch.tensor(score, dtype=torch.float, device=device).view(
            -1,
        )
        for score in all_scores
    ]
    # Pad 0 reward on the ends
    all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf)
    max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device)

    stats["time/rollout_score"] = time() - rollout_score_time

    all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind())
    all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
else:
    all_scores = None
    max_len = torch.tensor(0, dtype=torch.long, device=device)

if torch.distributed.is_initialized():
    scores = torch.empty(len(samples), device=device)
    torch.distributed.broadcast(max_len, 0)
    scores = torch.empty((len(samples), max_len), device=device)
    torch.distributed.scatter(scores, all_scores)
else:
    scores = all_scores[0].clone().detach()
scores_mask = scores != -np.inf

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)

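A standalone sketch of the padding scheme used here (toy values, not the trainer's data): per-token reward vectors of different lengths are right-padded with -inf so they form one rectangular tensor that torch.distributed.scatter can ship to each rank, and scores_mask recovers which entries are real.

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

all_scores = [torch.tensor([0.1, 0.2, 0.3]), torch.tensor([1.0])]
scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf)
scores_mask = scores != -np.inf  # True only at real (unpadded) positions

print(scores.shape)            # torch.Size([2, 3])
print(scores_mask.sum(dim=1))  # tensor([3, 1]) -- true per-sample lengths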
@@ -342,8 +359,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

# store statistics of the initial rollout as reference
if self.ref_mean is None:
    self.ref_mean, self.ref_std = scores.mean(), scores.std()
all_scores_mean, all_scores_std = self.running_moments.update(scores)
    self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(
        dim=1
    ).std()
all_scores_mean, all_scores_std = self.running_moments.update(torch.sum(scores * scores_mask, dim=1))
stats["rollout_scores/mean"] = all_scores_mean.item()
stats["rollout_scores/std"] = all_scores_std.item()
stats["rollout_scores/running_mean"] = self.running_moments.mean.item()
Expand Down Expand Up @@ -415,6 +434,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
    logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
    ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
else:
    # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
    logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
    ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])

@@ -439,7 +459,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
values = values.cpu()[:, :-1]

# Get the logprobs and values, for tokens that are not padding,
# from the start of the prompt up to the <eos> token, while also including the latter
# from the end of the prompt up to the <eos> token, while also including the latter
# (these are taken from the student model and not the reference model)
ends = start + attention_mask[:, start:].sum(1) + 1
all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
Expand All @@ -452,7 +472,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

for sample_idx in range(n_samples):
    rewards = kl_penalty[sample_idx]
    rewards[-1] += scores[sample_idx].cpu()
    # Then add in rewards
    if scores.shape[1] == 1:
        # NOTE: Final reward given at EOS token following HHH practice
        rewards[-1] += scores[sample_idx][0].cpu()
    else:
        score = scores[sample_idx]
        score_right_padding = torch.sum(scores_mask[sample_idx])
        score = score[:score_right_padding].cpu()
        p_score = torch.zeros_like(rewards)
        p_score[: score.shape[0]] += score
        rewards += p_score

    ppo_rl_elements.append(
        PPORLElement(
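Finally, a standalone sketch of how the dense branch above folds per-token scores into the per-token KL-penalty rewards for a single sample (toy tensors; in the trainer, rewards comes from kl_penalty and score from the scattered scores):

import torch

rewards = torch.tensor([-0.01, -0.02, -0.01, -0.03, -0.02])  # per-token KL penalty for one response

# Dense scores for that response: 3 real entries, right-padded to the batch max length
score = torch.tensor([0.0, 0.9, 0.8, float("-inf"), float("-inf")])
score_mask = score != float("-inf")

score_right_padding = torch.sum(score_mask)  # number of real score entries (3)
score = score[:score_right_padding]          # drop the padding
p_score = torch.zeros_like(rewards)
p_score[: score.shape[0]] += score           # align scores with the leading response tokens
rewards = rewards + p_score

print(rewards)  # tensor([-0.0100,  0.8800,  0.7900, -0.0300, -0.0200])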
