support parallel reward function #575

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions trlx/data/configs.py
@@ -231,6 +231,8 @@ class TrainConfig:

minibatch_size: Optional[int] = None

reward_only_in_main_process: bool = True

@classmethod
def from_dict(cls, config: Dict[str, Any]):
return cls(**config)
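For reference, a minimal sketch of toggling the new flag from user code. It assumes the `default_ppo_config` helper from `trlx.data.default_configs`; adapt it if your setup builds its `TRLConfig` differently. Only `reward_only_in_main_process` comes from this diff, everything else is illustrative.

from trlx.data.default_configs import default_ppo_config  # assumed helper; adjust to your config setup

config = default_ppo_config()
# Run the reward function on every process instead of gathering all samples
# and scoring them only on the main process (the previous behaviour).
config.train.reward_only_in_main_process = False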
90 changes: 56 additions & 34 deletions trlx/trainer/accelerate_base_trainer.py
@@ -387,19 +387,23 @@ def evaluate(self): # noqa: C901
if self.config.model.model_arch_type == "seq2seq":
samples = samples[:, 1:].contiguous()

prompt_sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(len(prompts.input_ids))
prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(
self.accelerator.pad_across_processes(
[prompts.input_ids, samples, prompt_sizes.to(samples.device)],
dim=1,
pad_index=self.tokenizer.pad_token_id,
)
prompt_sizes = torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(
len(prompts.input_ids)
)
if self.config.train.reward_only_in_main_process:
prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(
self.accelerator.pad_across_processes(
[prompts.input_ids, samples, prompt_sizes],
dim=1,
pad_index=self.tokenizer.pad_token_id,
)
)
metadata = gather_dict(metadata, self.accelerator.gradient_state)
else:
prompts = prompts.input_ids
all_samples.extend(samples.tolist())
all_prompts.extend(prompts.tolist())
all_prompt_sizes.extend(prompt_sizes.tolist())

metadata = gather_dict(metadata, self.accelerator.gradient_state)
all_metadata.append(metadata)

desc = [
@@ -412,11 +416,16 @@ def evaluate(self): # noqa: C901

stats["time/generate"] = time() - generate_time

if self.accelerator.is_main_process:
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes)

columns = ["prompt", "output"]
if self.accelerator.is_main_process:
columns = ["prompt", "output"]

# gather should be invoked in every process, not just the main process
columns_data = [str_prompts, str_outputs]
if not self.config.train.reward_only_in_main_process:
columns_data = self.accelerator.gather_for_metrics(columns_data)

metadata, *xs = all_metadata
for k in metadata:
@@ -439,41 +448,54 @@ def evaluate(self): # noqa: C901
rewards = torch.tensor([sum(reward) for reward in rewards], dtype=float)
else:
rewards = torch.tensor(rewards, dtype=float)
mean_reward = rewards.mean().item()
columns.append("reward")
if not isinstance(rewards, list):
rewards = rewards.tolist()
columns_data.append(rewards)
stats[f"reward/mean{sweep_suffix}"] = mean_reward

# gather should be invoked in every process, not just the main process
if not self.config.train.reward_only_in_main_process:
rewards = self.accelerator.gather(rewards)

if self.accelerator.is_main_process:
mean_reward = rewards.mean().item()

columns.append("reward")
if not isinstance(rewards, list):
rewards = rewards.tolist()
columns_data.append(rewards)
stats[f"reward/mean{sweep_suffix}"] = mean_reward

# additionally log any other metrics
if self.metric_fn:
logger.info("Computing metrics")
metric_time = time()
metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata)
stats["time/metric"] = time() - metric_time
if not self.config.train.reward_only_in_main_process:
metrics = self.accelerator.gather_for_metrics(metrics)

mean_metrics = {
f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() for k, xs in metrics.items()
}
if self.accelerator.is_main_process:
stats["time/metric"] = time() - metric_time

stats.update(mean_metrics)
mean_metrics = {
f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item()
for k, xs in metrics.items()
}

for metric, values in metrics.items():
# Skip metrics that are scalars since they represent aggregated values
if isinstance(values, float):
continue
columns.append(metric)
if not isinstance(values, list):
values = values.tolist()
columns_data.append(values)
stats.update(mean_metrics)

for metric, values in metrics.items():
# Skip metrics that are scalars since they represent aggregated values
if isinstance(values, float):
continue
columns.append(metric)
if not isinstance(values, list):
values = values.tolist()
columns_data.append(values)

# Prepend the sweep argument along with samples
if self.generate_sweep_kwarg:
columns.insert(0, gen_sweep_arg)
columns_data.insert(0, [gen_sweep_value] * len(samples))
if self.accelerator.is_main_process:
if self.generate_sweep_kwarg:
columns.insert(0, gen_sweep_arg)
columns_data.insert(0, [gen_sweep_value] * len(samples))

table.append(list(zip(*columns_data)))
table.append(list(zip(*columns_data)))

# Log and display evaluation metrics
logger.info("Summarizing evaluation")
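A standalone toy sketch of the per-process scoring path in evaluate() above: each rank scores only its own decoded samples, every rank participates in the gather, and only the main process logs the mean. `toy_reward_fn` and the sample strings are made up stand-ins for the user-supplied reward_fn and the local evaluation slice.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

def toy_reward_fn(samples, prompts, outputs, **metadata):
    # Placeholder reward: length of each generated continuation
    return [float(len(o)) for o in outputs]

# Each rank holds only its local slice of the evaluation batch
local_prompts = ["Once upon a time"]
local_outputs = [" there was a parallel reward function."]
local_samples = [p + o for p, o in zip(local_prompts, local_outputs)]

rewards = torch.tensor(
    toy_reward_fn(samples=local_samples, prompts=local_prompts, outputs=local_outputs),
    dtype=torch.float,
    device=accelerator.device,
)
# gather must be invoked on every process, not just the main process
rewards = accelerator.gather(rewards)
if accelerator.is_main_process:
    print("reward/mean:", rewards.mean().item())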
58 changes: 37 additions & 21 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -289,18 +289,25 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
device = samples.device

prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"})
metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}

if self.config.train.reward_only_in_main_process:
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict(metadata)
else:
gathered_samples = samples
gathered_prompts = prompt_tensors
gathered_prompt_sizes = prompt_sizes

if self.accelerator.is_main_process:
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
all_str_samples, all_str_prompts, all_str_outputs = self.decode(
gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True
)
@@ -316,9 +323,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
**metadata,
)
all_scores = [
torch.tensor(score, dtype=torch.float, device=device).view(
-1,
)
score.view(-1)
if isinstance(score, torch.Tensor)
else torch.tensor(score, dtype=torch.float, device=device).view(-1)
for score in all_scores
]
# Pad 0 reward on the ends
@@ -327,20 +334,29 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

stats["time/rollout_score"] = time() - rollout_score_time

all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
if self.config.train.reward_only_in_main_process:
all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
else:
all_scores = None
max_len = torch.tensor(0, dtype=torch.long, device=device)

if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores)
if self.config.train.reward_only_in_main_process:
if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores)  # after the scatter, `scores` holds this process's shard
else:
scores = all_scores[0].clone().detach() # shard of one process
else:
scores = all_scores[0].clone().detach()
scores = all_scores.clone().detach() # shard of one process
# `all_scores` no longer used, no need to gather it
scores_mask = scores != -np.inf

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
if self.config.train.reward_only_in_main_process:
_, _, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
else:
str_outputs = all_str_outputs
# `all_str_outputs` no longer used, no need to gather it

# Pad the sample outputs
outputs = self.tokenizer(str_outputs).input_ids
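For intuition, a toy sketch (outside of trlx) of how the main process shards the padded score tensor before torch.distributed.scatter in the reward_only_in_main_process path: per-sample scores are right-padded to max_len with -np.inf (consistent with the scores_mask check above), reshaped into one shard per process, and each rank ends up with its own shard. The shapes and values here are made up.

import numpy as np
import torch

num_processes, samples_per_process, max_len = 2, 3, 4
# One row per sample across all processes, padded with -inf as in make_experience
all_scores = torch.full((num_processes * samples_per_process, max_len), -np.inf)
all_scores[:, 0] = torch.arange(num_processes * samples_per_process, dtype=torch.float)

# One shard per process; torch.distributed.scatter would send shard i to rank i
shards = list(all_scores.reshape(num_processes, -1, max_len).unbind())
assert len(shards) == num_processes and shards[0].shape == (samples_per_process, max_len)

# Each rank then masks out the padding, as make_experience does with scores_mask
scores = shards[0]
scores_mask = scores != -np.inf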