
support parallel reward function #575

Open · wants to merge 6 commits into main
Conversation

Jingru
Contributor

@Jingru Jingru commented Oct 24, 2023

Currently, reward_fn is invoked only in the main process, which hangs if reward_fn actually wraps a parallel model (e.g. one optimized with TP/PP/ZeRO).
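
A minimal, hypothetical sketch of the failure mode (not trlx's actual code; the function and argument names are illustrative): a TP/PP/ZeRO-sharded reward model's forward pass is a collective operation, so it only returns if every rank enters it.

import torch.distributed as dist

def score_on_main_process_only(samples, reward_fn):
    # Current pattern: only rank 0 scores the gathered samples. If reward_fn wraps
    # a sharded model, rank 0 blocks inside the forward's collectives waiting for
    # peers that never call reward_fn, so training hangs.
    if dist.get_rank() == 0:
        return reward_fn(samples)
    return None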

@codecov-commenter

codecov-commenter commented Oct 24, 2023

Codecov Report

Attention: 26 lines in your changes are missing coverage. Please review.

Comparison is base (91a0f43) 43.58% compared to head (e0a6ba2) 43.62%.
Report is 1 commit behind head on main.

❗ Current head e0a6ba2 differs from the pull request's most recent head 2cf8af5. Consider uploading reports for commit 2cf8af5 to get more accurate results.

Files Patch % Lines
trlx/trainer/accelerate_base_trainer.py 51.35% 18 Missing ⚠️
trlx/trainer/accelerate_ppo_trainer.py 68.00% 8 Missing ⚠️


Additional details and impacted files
@@            Coverage Diff             @@
##             main     #575      +/-   ##
==========================================
+ Coverage   43.58%   43.62%   +0.03%     
==========================================
  Files          33       33              
  Lines        4974     4997      +23     
==========================================
+ Hits         2168     2180      +12     
- Misses       2806     2817      +11     


@lzy37ld

lzy37ld commented Oct 29, 2023

Hi @Jingru, any updates here? I am also looking for this approach. Thanks!

@lzy37ld

lzy37ld commented Oct 29, 2023

Actually, I am a bit confused: why does the reward model have to be placed separately on the last GPU? Can't the reward model be parallelized like the policy model?

@Jingru
Contributor Author

Jingru commented Oct 29, 2023

Hi @Jingru, any updates here? I am also looking for this approach. Thanks!

I believe this PR solves that issue; I've already tested it myself.

Maybe the abstraction of a reward "function" rather than a reward "model" is the reason it is invoked in only ONE process/GPU.

@Jingru
Contributor Author

Jingru commented Oct 29, 2023

Another possible improvement: with the hydra LoRA architecture, the frozen foundation model weights could also be shared among the "reward model", the "actor model", and the "critic model".

@lzy37ld

lzy37ld commented Oct 29, 2023

Thanks!
Actually, I am new to distributed training, so my question is roughly: why don't they combine the reward model into something like all_model = All_Model(actor_model, critic_model, reward_model) and just call accelerator.prepare(all_model)?

If I understand correctly, suppose we have two processes 0 and 1. Each process runs through the distributed model to get its generations, and then it could compute the reward for its own batch as well. Why do they need to separate the models and use tricks like gather, broadcast, etc.? I am confused about it :<

Much appreciated if you have time to reply.

@lzy37ld

lzy37ld commented Oct 29, 2023

For example, I stopped my debugger at this point:

return self.accelerator.unwrap_model(self.model).generate(

and found that self.model (the whole model, rather than a partitioned one) sits on a different device in each process. That shouldn't be the case, right? As stated here, ZeRO-DP is supposed to split the model into different slices.

@Jingru
Contributor Author

Jingru commented Oct 30, 2023

For example, I stopped my debugger at this point:

return self.accelerator.unwrap_model(self.model).generate(

and found that self.model (the whole model, rather than a partitioned one) sits on a different device in each process. That shouldn't be the case, right? As stated here, ZeRO-DP is supposed to split the model into different slices.

This is the expected behavior of DeepSpeed: every process has a model (wrapper), and each wrapper holds only a shard of the model weights when ZeRO-3 is enabled. Besides, the inputs (prompts) and outputs are also sharded across processes (for DP). That's why a gather is needed if only one process does the rewarding. So, if we use a ZeRO-3-enabled reward model, we don't have to invoke gather and scatter any more.
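
To make the two data paths concrete, here is a hedged sketch (the helper names are made up, not trlx's API) of the gather/scatter round-trip a rank-0-only reward_fn requires when prompts and outputs are sharded across DP ranks, versus the purely local call a ZeRO-3 reward model allows:

import torch.distributed as dist

def score_via_rank0(local_samples, reward_fn):
    # Old path: collect every rank's samples on rank 0, score there, send scores back.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_samples, gathered, dst=0)
    score_chunks = [reward_fn(chunk) for chunk in gathered] if rank == 0 else None
    out = [None]
    dist.scatter_object_list(out, score_chunks, src=0)
    return out[0]

def score_locally(local_samples, reward_fn):
    # New path: with a ZeRO-3 sharded reward model, every rank scores its own batch;
    # the sharded forward does the cross-rank communication internally.
    return reward_fn(local_samples)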

@maxreciprocate
Collaborator

Hi @Jingru! Thanks for the work you've done here! Could you share the script with the parallel or sharded reward model you used?

@Jingru
Contributor Author

Jingru commented Oct 31, 2023

Hi @Jingru! Thanks for the work you've done here! Could you share the script with the parallel or sharded reward model you used?

I borrowed the reward model initialization function from DeepSpeed-Chat and defined reward_fn like this:


# get_eval_ds_config and create_critic_model are borrowed from DeepSpeed-Chat's
# training utils; BATCH_SIZE_PER_GPU and the trlx `config` object are defined
# elsewhere in the summarize_rlhf example.
from typing import List

import deepspeed
import torch
from transformers import AutoTokenizer

rw_tokenizer = AutoTokenizer.from_pretrained("llama_critic")

ds_eval_config = get_eval_ds_config(offload=True, stage=3)
# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config["train_micro_batch_size_per_gpu"] = BATCH_SIZE_PER_GPU
ds_eval_config["train_batch_size"] = BATCH_SIZE_PER_GPU * 8 * 1

rw_model, *_ = deepspeed.initialize(
    model=create_critic_model(
        model_name_or_path="llama_critic",
        tokenizer=rw_tokenizer,
        ds_config=ds_eval_config,
        num_padding_at_beginning=0,
        rlhf_training=True,
        disable_dropout=True,
        zero_stage=3,
    ),
    config=ds_eval_config,
)


def get_scores(samples: List[str], prompt_length: int = 0):
    samples = ["<|startoftext|>" + chosen + "<|endoftext|>" for chosen in samples]
    encodings_dict = rw_tokenizer(
        samples,
        truncation=True,
        max_length=config.train.seq_length,
        padding="max_length",
        return_tensors="pt",
    )
    # "rwtranrsformer" (sic) is the attribute name used by DeepSpeed-Chat's RewardModel
    input_ids = encodings_dict["input_ids"].to(rw_model.rwtranrsformer.device)
    attn_masks = encodings_dict["attention_mask"].to(rw_model.rwtranrsformer.device)
    # input_ids = input_ids.repeat(2, 1)
    # attn_masks = attn_masks.repeat(2, 1)
    with torch.no_grad():
        scores = rw_model.forward_value(
            input_ids=input_ids, attention_mask=attn_masks, prompt_length=prompt_length
        )
    return scores["chosen_end_scores"]


def reward_fn(samples: List[str], **kwargs):
    prompts = kwargs["prompts"]
    prompt_length = len(prompts[0])
    original_outputs = kwargs["original_output"]
    original_scores = get_scores(
        [prompt + output for prompt, output in zip(prompts, original_outputs)],
        prompt_length,
    )
    scores = get_scores(samples, prompt_length)
    norms_scores = scores - original_scores
    return norms_scores

I believe we can just use deepspeed.initialize to adapt the reward_fn of the provided summarize_rlhf example and make it parallel.
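
For completeness, a hedged sketch of how such a reward_fn would then be passed to trlx in the summarize_rlhf example (argument names follow trlx's public trlx.train API; the prompt lists and config are assumed to be defined elsewhere in that example):

import trlx

trainer = trlx.train(
    reward_fn=reward_fn,       # the DeepSpeed-initialized, sharded reward function above
    prompts=train_prompts,     # assumed to be defined elsewhere in the example
    eval_prompts=val_prompts,  # assumed to be defined elsewhere in the example
    config=config,
)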

@lzy37ld

lzy37ld commented Oct 31, 2023

Looks like computing the reward in parallel is more efficient? We no longer need to gather or broadcast just because the reward model lives on the last device.

@lzy37ld

lzy37ld commented Oct 31, 2023

But how much VRAM does stage 3 need?
Say I have just two relatively small GPUs, 5 GB each.
If our model is 6 GB and, after partitioning, each GPU needs 3 GB of VRAM, does that mean it would still OOM? (assuming no offload_params)

The reason I ask is that I just found the whole model gets placed on each device without partitioning when stage 3 is enabled. Probably the partitioning only happens during the forward pass. But if the model already fits on the two devices, why do we still need partitioning?

The code I tested is from HF (please see the lines where I print the device and parameters):


#!/usr/bin/env python

# This script demonstrates how to use Deepspeed ZeRO in an inference mode when one can't fit a model
# into a single GPU
#
# 1. Use 1 GPU with CPU offload
# 2. Or use multiple GPUs instead
#
# First you need to install deepspeed: pip install deepspeed
#
# Here we use a 3B "bigscience/T0_3B" model which needs about 15GB GPU RAM - so 1 largish or 2
# small GPUs can handle it. or 1 small GPU and a lot of CPU memory.
#
# To use a larger model like "bigscience/T0" which needs about 50GB, unless you have an 80GB GPU -
# you will need 2-4 gpus. And then you can adapt the script to handle more gpus if you want to
# process multiple inputs at once.
#
# The provided deepspeed config also activates CPU memory offloading, so chances are that if you
# have a lot of available CPU memory and you don't mind a slowdown you should be able to load a
# model that doesn't normally fit into a single GPU. If you have enough GPU memory the program will
# run faster if you don't want offload to CPU - so disable that section then.
#
# To deploy on 1 gpu:
#
# deepspeed --num_gpus 1 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=1 t0.py
#
# To deploy on 2 gpus:
#
# deepspeed --num_gpus 2 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=2 t0.py


from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers.integrations import HfDeepSpeedConfig
import deepspeed
import os
import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To avoid warnings about parallelism in tokenizers

# distributed setup
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()

model_name = "bigscience/T0_3B"

config = AutoConfig.from_pretrained(model_name)
model_hidden_size = config.d_model

# batch size has to be divisible by world_size, but can be bigger than world_size
train_batch_size = 1 * world_size

# ds_config notes
#
# - enable bf16 if you use Ampere or higher GPU - this will run in mixed precision and will be
# faster.
#
# - for older GPUs you can enable fp16, but it'll only work for non-bf16 pretrained models - e.g.
# all official t5 models are bf16-pretrained
#
# - set offload_param.device to "none" or completely remove the `offload_param` section if you don't
# - want CPU offload
#
# - if using `offload_param` you can manually finetune stage3_param_persistence_threshold to control
# - which params should remain on gpus - the larger the value the smaller the offload size
#
# For indepth info on Deepspeed config see
# https://huggingface.co/docs/transformers/main/main_classes/deepspeed

# keeping the same format as json for consistency, except it uses lower case for true/false
# fmt: off
ds_config = {
    "fp16": {
        "enabled": False
    },
    "bf16": {
        "enabled": False
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}
# fmt: on

# next line instructs transformers to partition the model directly over multiple gpus using
# deepspeed.zero.Init when model's `from_pretrained` method is called.
#
# **it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name)**
#
# otherwise the model will first be loaded normally and only partitioned at forward time which is
# less efficient and when there is little CPU RAM may fail
dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

# now a model can be loaded.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# initialise Deepspeed ZeRO and store only the engine object
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
# I just print their params and device here:

print(ds_engine)
print(ds_engine.device)

ds_engine.module.eval()  # inference
# I just print their params and device here:

# Deepspeed ZeRO can process unrelated inputs on each GPU. So for 2 gpus you process 2 inputs at once.
# If you use more GPUs adjust for more.
# And of course if you have just one input to process you then need to pass the same string to both gpus
# If you use only one GPU, then you will have only rank 0.
rank = torch.distributed.get_rank()
if rank == 0:
    text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
    text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")


@Jingru
Contributor Author

Jingru commented Oct 31, 2023

I just found that the whole model gets placed on each device without partitioning when stage 3 is enabled

You can set --zero3_init_flag=true for accelerate so that DeepSpeed initializes a large model directly in sharded form, rather than loading the whole model onto ONE GPU and then partitioning it.
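
For reference, a minimal sketch of requesting the same behavior from Python through accelerate's DeepSpeedPlugin (option names as in accelerate's DeepSpeed integration; check them against your accelerate version):

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Hedged sketch: zero3_init_flag wraps HF from_pretrained in deepspeed.zero.Init,
# so each rank allocates only its own ZeRO-3 shard instead of materializing the
# full model first and partitioning it later.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,
    zero3_init_flag=True,  # same effect as `accelerate launch --zero3_init_flag=true`
)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)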

@lzy37ld

lzy37ld commented Oct 31, 2023

Thanks for this, @Jingru!

I carefully checked it again, and it feels like the model is already being sharded the way --zero3_init_flag would do it; see the comments in the code:

'
next line instructs transformers to partition the model directly over multiple gpus using deepspeed.zero.Init when model's from_pretrained method is called.

it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name)

otherwise the model will first be loaded normally and only partitioned at forward time which is less efficient and when there is little CPU RAM may fail
'

Besides, I watched the model's parameters after from_pretrained()
[image]
and they look like empty tensors. So my question is: what is it that actually occupies the VRAM?
(I set a breakpoint after from_pretrained and watched the GPU usage.)
[image]
Why are they 7k and 22k (MB), about 30k in total, which is larger than the size of T0-3B (3 * 4 = 12k)?
That's weird...

@Jingru
Contributor Author

Jingru commented Nov 4, 2023

Thanks for this, @Jingru!

I carefully checked it again, and it feels like the model is already being sharded the way --zero3_init_flag would do it; see the comments in the code:

'next line instructs transformers to partition the model directly over multiple gpus using deepspeed.zero.Init when model's from_pretrained method is called. it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name) otherwise the model will first be loaded normally and only partitioned at forward time which is less efficient and when there is little CPU RAM may fail'

Besides, I watched the model's parameters after from_pretrained() [image] and they look like empty tensors. So my question is: what is it that actually occupies the VRAM? (I set a breakpoint after from_pretrained and watched the GPU usage.) [image] Why are they 7k and 22k (MB), about 30k in total, which is larger than the size of T0-3B (3 * 4 = 12k)? That's weird...

You're right about zero3_init_flag.

DeepSpeed has its own VRAM manager, and the tensors in the model are just placeholders, so they are empty after model initialization. I'm not familiar with the model-sharding details of DeepSpeed ZeRO-3; you may want to check DeepSpeed's implementation regarding the imbalance of the shards.

As for the VRAM occupation, I guess it includes optimizer states for every parameter, and those may be float32 if you're using mixed precision.
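
As a rough sanity check, the usual ZeRO back-of-envelope numbers (a hedged rule of thumb from the ZeRO paper; it ignores activations, buckets, and buffers, so it will not exactly reproduce the 7k/22k readings above):

# Per-GPU memory for a 3B-parameter model sharded across 2 GPUs under ZeRO-3.
n_params = 3e9
world_size = 2
GiB = 2**30

# Inference with fp32 weights only:
inference_fp32 = 4 * n_params / world_size / GiB                 # ~5.6 GiB per GPU
# Mixed-precision training with Adam: 2 B fp16 weights + 2 B fp16 grads
# + 12 B fp32 optimizer state per parameter:
training_mixed = (2 + 2 + 12) * n_params / world_size / GiB      # ~22.4 GiB per GPU
print(f"{inference_fp32:.1f} GiB/GPU (fp32 inference), {training_mixed:.1f} GiB/GPU (mixed-precision training)")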

@lzy37ld

lzy37ld commented Nov 7, 2023

Oh, if I understand correctly, accelerate cannot call prepare twice in one script, so that's why you use deepspeed.initialize.
