Llama NeMo support #542

Merged: 22 commits, Sep 22, 2023
32 changes: 32 additions & 0 deletions examples/llama_nemo/README.md
@@ -0,0 +1,32 @@
### NeMo Megatron setup:

- Install NeMo v1.17.0:

```bash
git clone https://github.com/NVIDIA/NeMo/
cd NeMo
git checkout d3017e4
pip install -e '.[all]'
```

- Install Apex:
```bash
git clone https://github.com/NVIDIA/apex
cd apex
# requires pip >= 23.1, which supports multiple `--config-settings` with the same key (ref: https://pip.pypa.io/en/stable/news/#v23-1)
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```
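
- Optionally, verify that both packages import cleanly (a quick sanity check, assuming they were installed into the active environment):
```bash
python -c "import nemo, apex; print(nemo.__version__)"
```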

### Convert LLaMA to NeMo:
Example:

```bash
python convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b
```
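
With `--total_tp 4`, the converter writes one weight shard per tensor-parallel rank plus a matching Megatron config, so the output folder should look like:

```
nemo_llama2_7b/
├── megatron_7b.yaml
├── mp_rank_00/model_weights.ckpt
├── mp_rank_01/model_weights.ckpt
├── mp_rank_02/model_weights.ckpt
└── mp_rank_03/model_weights.ckpt
```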

### Training:
Example training run: [wandb](https://wandb.ai/carperai/trlxnemo/runs/v7592y73?workspace=user-pvduy)

```bash
sbatch dist_train.sh
```
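
Note: `dist_train.sh` requests a single exclusive node with 8 tasks (intended as one per GPU) on the `g40` partition; adjust the `#SBATCH` directives to match your cluster.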
144 changes: 144 additions & 0 deletions examples/llama_nemo/convert_llama_to_nemo.py
@@ -0,0 +1,144 @@
# flake8: noqa

import os
from pathlib import Path

import torch
from omegaconf.omegaconf import OmegaConf
from transformers import AutoModelForCausalLM


def main(args): # noqa: C901
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
print("Loaded model")

model_state_dict = model.state_dict()
print("Model loaded")

    # Constants
    TOTAL_LAYERS = model.config.num_hidden_layers
    TOTAL_TP = args.total_tp  # tensor-parallel degree
    HIDDEN_DIM = model.config.hidden_size
    FFN_HIDDEN_DIM = model.config.intermediate_size
    PART_ATTN_DIM = HIDDEN_DIM // TOTAL_TP  # per-rank slice of each attention projection
    PART_MLP_DIM = FFN_HIDDEN_DIM // TOTAL_TP  # per-rank slice of each MLP projection
    VOCAB_SIZE = model.config.vocab_size
    EMBEDDING_DIM = VOCAB_SIZE // TOTAL_TP  # per-rank slice of the vocab dimension
    OUTPUT_FOLDER = args.output_folder  # destination folder for the converted NeMo checkpoint

    # Weight mapping helpers: NeMo parameter name -> HF parameter name(s)

def build_layer_mapping(layer):
return {
f"model.language_model.encoder.layers.{layer}.input_layernorm.weight": f"model.layers.{layer}.input_layernorm.weight",
f"model.language_model.encoder.layers.{layer}.self_attention.query_key_value.weight": [
f"model.layers.{layer}.self_attn.q_proj.weight",
f"model.layers.{layer}.self_attn.k_proj.weight",
f"model.layers.{layer}.self_attn.v_proj.weight",
],
f"model.language_model.encoder.layers.{layer}.self_attention.dense.weight": f"model.layers.{layer}.self_attn.o_proj.weight",
f"model.language_model.encoder.layers.{layer}.post_attention_layernorm.weight": f"model.layers.{layer}.post_attention_layernorm.weight",
f"model.language_model.encoder.layers.{layer}.mlp.dense_h_to_4h.weight": f"model.layers.{layer}.mlp.gate_proj.weight",
f"model.language_model.encoder.layers.{layer}.mlp.dense_h_to_4h_2.weight": f"model.layers.{layer}.mlp.up_proj.weight",
f"model.language_model.encoder.layers.{layer}.mlp.dense_4h_to_h.weight": f"model.layers.{layer}.mlp.down_proj.weight",
}
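
    # Note: NeMo's dense_h_to_4h / dense_h_to_4h_2 correspond to LLaMA's SwiGLU
    # gate_proj / up_proj, and query_key_value fuses Q, K and V row slices
    # (see get_self_attention_weight below)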

    def save_nemo_state_dict(nemo_state_dict, tp_idx):
        # NeMo expects one shard per tensor-parallel rank in mp_rank_0X/ subfolders
        if TOTAL_TP == 1:
            os.makedirs(OUTPUT_FOLDER, exist_ok=True)
            torch.save(nemo_state_dict, f"{OUTPUT_FOLDER}/model_weights.ckpt")
        else:
            os.makedirs(f"{OUTPUT_FOLDER}/mp_rank_0{tp_idx}", exist_ok=True)
            torch.save(nemo_state_dict, f"{OUTPUT_FOLDER}/mp_rank_0{tp_idx}/model_weights.ckpt")

def map_weights(tp_idx):
nemo_state_dict = {}

# Word embeddings mapping

nemo_state_dict["model.language_model.embedding.word_embeddings.weight"] = model_state_dict[
"model.embed_tokens.weight"
][tp_idx * EMBEDDING_DIM : (tp_idx + 1) * EMBEDDING_DIM, :]

nemo_state_dict["model.language_model.encoder.final_layernorm.weight"] = model_state_dict["model.norm.weight"]

nemo_state_dict["model.language_model.output_layer.weight"] = model_state_dict["lm_head.weight"][
tp_idx * EMBEDDING_DIM : (tp_idx + 1) * EMBEDDING_DIM, :
]

# Other layer mappings
for layer in range(TOTAL_LAYERS):
layer_mapping = build_layer_mapping(layer)
for k in layer_mapping.keys():
if "self_attention.query_key_value.weight" in k:
nemo_state_dict[k] = get_self_attention_weight(model_state_dict, layer_mapping, k, tp_idx)
elif "self_attention.dense.weight" in k:
nemo_state_dict[k] = model_state_dict[layer_mapping[k]][
:, tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM
]
elif "mlp.dense_h_to_4h.weight" in k or "mlp.dense_h_to_4h_2.weight" in k:
nemo_state_dict[k] = get_mlp_weight(model_state_dict, layer_mapping, k, tp_idx)
elif "mlp.dense_4h_to_h.weight" in k:
nemo_state_dict[k] = model_state_dict[layer_mapping[k]][
:, tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM
]
                else:
                    nemo_state_dict[k] = model_state_dict[layer_mapping[k]]

        # Break view relationships, otherwise PyTorch would serialize the full
        # original tensors that back the slices
        nemo_state_dict = {k: v.clone() for k, v in nemo_state_dict.items()}
save_nemo_state_dict(nemo_state_dict, tp_idx)

    def get_self_attention_weight(model_state_dict, layer_mapping, key, tp_idx):
        # NeMo fuses the attention projections into a single query_key_value matrix:
        # each TP rank takes its row slice of Q, K and V and concatenates them along dim 0
        llama_query = model_state_dict[layer_mapping[key][0]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :]
        llama_key = model_state_dict[layer_mapping[key][1]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :]
        llama_value = model_state_dict[layer_mapping[key][2]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :]
        return torch.cat([llama_query, llama_key, llama_value], dim=0)

    def get_mlp_weight(model_state_dict, layer_mapping, key, tp_idx):
        # gate_proj / up_proj are split along the output (row) dimension per TP rank
        llama_weight = model_state_dict[layer_mapping[key]]
        return llama_weight[tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM, :]

    # Build a matching NeMo/Megatron config from the template YAML next to this script

    megatron_cfg_path = Path(__file__).parent / "megatron_7b_llama.yaml"

megatron_cfg = OmegaConf.load(megatron_cfg_path)
megatron_cfg.name = f"megatron_{args.name}"
megatron_cfg.trainer.num_nodes = 1
megatron_cfg.trainer.devices = TOTAL_TP
megatron_cfg.model.tensor_model_parallel_size = TOTAL_TP
megatron_cfg.model.padded_vocab_size = model.config.vocab_size
megatron_cfg.model.hidden_size = model.config.hidden_size
megatron_cfg.model.ffn_hidden_size = model.config.intermediate_size
megatron_cfg.model.num_layers = model.config.num_hidden_layers
megatron_cfg.model.num_attention_heads = model.config.num_attention_heads
megatron_cfg.model.max_position_embeddings = model.config.max_position_embeddings
megatron_cfg.model.seq_length = model.config.max_position_embeddings

megatron_cfg.exp_manager.create_wandb_logger = False
megatron_cfg.exp_manager.create_checkpoint_callback = False

print("Mapping weights")
for tp in range(TOTAL_TP):
map_weights(tp)

OmegaConf.save(megatron_cfg, str(Path(OUTPUT_FOLDER) / f"megatron_{args.name}.yaml"))

print("Done")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--total_tp", type=int, required=True)
    parser.add_argument("--name", type=str, required=True)
    main(parser.parse_args())
13 changes: 13 additions & 0 deletions examples/llama_nemo/dist_train.sh
@@ -0,0 +1,13 @@
#!/bin/bash
#SBATCH --job-name=llama
#SBATCH --partition=g40
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --mem=0
#SBATCH --cpus-per-task=8
#SBATCH --output=out.txt
#SBATCH --error=error.txt
#SBATCH --exclusive

cd examples/llama_nemo
srun --label python nemo_llama2_ppo_sentiments.py
115 changes: 115 additions & 0 deletions examples/llama_nemo/nemo_llama2_ppo_sentiments.py
@@ -0,0 +1,115 @@
# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

from datasets import load_dataset
from transformers import DistilBertForSequenceClassification, pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    """Extract the score of the POSITIVE label from the sentiment pipeline's output."""
    # e.g. [{"label": "NEGATIVE", "score": 0.1}, {"label": "POSITIVE", "score": 0.9}] -> 0.9
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def load_nemo_config():
    """Load the converted Llama-2-7B NeMo model and trainer config"""
    # Import here to not require nemo as a dependency
    from omegaconf import OmegaConf

    return OmegaConf.load("nemo_llama2_7b/megatron_7b.yaml")


def main(hparams={}):
# Merge sweep config with default config if given
default_config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
nemo_config = load_nemo_config()
print(nemo_config)
cfg_name = "llama2-7b"
config = default_config.evolve(
train=dict(
total_steps=1600,
seq_length=256,
batch_size=16,
epochs=100,
eval_interval=100,
trainer="NeMoPPOTrainer",
trainer_kwargs=dict(
pretrained_model="nemo_llama2_7b/",
megatron_cfg=nemo_config,
),
checkpoint_interval=256,
checkpoint_dir=f"nemo_{cfg_name}_ppo_sentiments",
seed=2023,
project_name="trlxnemo",
tags=["nemo", "ppo", "sentiments", cfg_name],
),
optimizer=dict(
name="adamw",
kwargs=dict(
lr=1e-5,
weight_decay=1e-06,
eps=1.0e-8,
betas=(0.9, 0.95),
),
),
scheduler=dict(
name="CosineAnnealing",
),
        model=dict(num_layers_unfrozen=24),  # only the last 24 transformer layers stay trainable
method=dict(
num_rollouts=128,
init_kl_coef=0.05,
vf_coef=1,
scale_reward="ignored",
gamma=1,
lam=0.95,
cliprange=0.2,
cliprange_value=0.2,
gen_kwargs=dict(temperature=1.0, max_new_tokens=64),
chunk_size=64,
ppo_epochs=4,
),
)
config.scheduler.kwargs = dict(warmup_steps=0, constant_steps=1e12, min_lr=1e-6)

    rank = int(os.environ["SLURM_PROCID"])
    local_rank = rank % 8  # dist_train.sh launches 8 tasks per node, one per GPU

    # Keep the reward model on CPU between scoring calls to leave GPU memory to the policy
    reward_model = DistilBertForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb")
    reward_model.to("cpu")
    sentiment_fn = pipeline(
        "sentiment-analysis",
        model=reward_model,
        tokenizer="lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=local_rank,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        # Move the reward model onto this rank's GPU only for scoring, then back to CPU
        reward_model.to(local_rank)
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        reward_model.to("cpu")
        return sentiments

    # Take the first few words of movie reviews as prompts
imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
trlx.train(
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about Hungarian underground"] * 256,
config=config,
)


if __name__ == "__main__":
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)
30 changes: 28 additions & 2 deletions trlx/models/modeling_nemo_ppo.py
@@ -28,6 +28,7 @@
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
MegatronGPTModel,
)
from nemo.collections.nlp.modules.common.megatron.attention import ParallelAttention
from nemo.collections.nlp.modules.common.megatron.module import (
Float16Module,
MegatronModule,
@@ -58,6 +59,11 @@
_PER_DP_RANK_RNG = "per-data-parallel-rank-rng"


def patch_attention_for_llama(m):
    # Checkpoints from convert_llama_to_nemo.py store fused QKV as contiguous
    # [Q; K; V] blocks, the layout NeMo selects with the megatron_legacy flag
    if isinstance(m, ParallelAttention):
        m.megatron_legacy = True


class ParallelLinear(nn.Module):
"""Linear layer parallelized over the longer dimension."""

@@ -181,7 +187,13 @@ def __init__(self, language_model, other_heads, build_reference_model=True):
self.reference_model_offloaded = True

self.other_heads = other_heads
        if hasattr(language_model, "output_layer"):
            # LLaMA keeps an untied output projection instead of reusing the input embeddings
            self.output_layer = self._lm.language_model.output_layer
            self.word_embeddings = self.output_layer.weight
        else:
            if hasattr(language_model, "word_embeddings"):
                self.word_embeddings = language_model.word_embeddings
            self.output_layer = None

# The tensor from the previous pipeline rank arrives via this method
def set_input_tensor(self, input_tensor):
@@ -194,6 +206,15 @@ def load_state_dict(self, lm_state_dict, strict=True):
"""Load GPTModel state dict."""
self.language_model.load_state_dict(lm_state_dict, strict=strict)

if "output_layer.weight" in lm_state_dict:
dtype = lm_state_dict["output_layer.weight"].dtype
device = self.language_model.output_layer.weight.device
params = torch.nn.Parameter(
lm_state_dict["output_layer.weight"].to(device, dtype=dtype), requires_grad=True
)
self.language_model.output_layer.weight = params
print("Loaded output_layer.weight from lm_state_dict")

if self.build_reference_model:
for p in self.language_model.parameters():
if p.requires_grad:
@@ -232,7 +253,10 @@ def forward(
run_value_head=False,
**kwargs,
):
        if hasattr(self._lm.language_model, "output_layer"):
            # LLaMA: logits come from the untied output projection
            logit_weights = self._lm.language_model.output_layer.weight
        else:
            logit_weights = self._lm.word_embeddings_weight()

if run_policy_model:
self.offload_reference_model()
@@ -501,6 +525,8 @@ def freeze_layers(m):
p.requires_grad_(False)
gpt.language_model.apply(freeze_layers)

if self.cfg.get("megatron_legacy", False):
gpt.apply(patch_attention_for_llama)
# If running on the last pipeline stage, add the PPO value head and hydra reference model
if post_process:
value_head = ValueHead(self.cfg.hidden_size, self.cfg.sequence_parallel)
8 changes: 8 additions & 0 deletions trlx/models/modeling_nemo_sft.py
@@ -125,6 +125,14 @@ def trim_key(key, prefix):
unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True)
print(f"Loaded from pretrained {rank_params}")

    def model_provider_func(self, *args, **kwargs):
        gpt = super().model_provider_func(*args, **kwargs)

        # Converted LLaMA checkpoints use the legacy fused-QKV layout (see modeling_nemo_ppo)
        from trlx.models.modeling_nemo_ppo import patch_attention_for_llama

        gpt.apply(patch_attention_for_llama)
        return gpt

# Adapted from NeMo
# https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259
def training_step(self, batch: List[torch.Tensor], batch_idx: int): # noqa: C901