Skip to content

Commit

Permalink
Llama NeMo support (#542)
Browse files Browse the repository at this point in the history
* init

* fix logging bug

* sync

* optimize hparams and code

* tweak hp

* sync

* fix loading issue

* fix annoying off by one error

* ref logps

* cleanup

* remove unused

* fix bug when no config specified

* parametric over llama size

* strict=True

* add readme and llama2 example

* hparams 7b

* fix style

* Update README.md

* remove bad file and 7b llama2 config

* update wandb

---------

Co-authored-by: cat-state <cat@meow>
Co-authored-by: Duy V. Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
  • Loading branch information
4 people committed Sep 22, 2023
1 parent 3685a75 commit ffa5ba1
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 2 deletions.
32 changes: 32 additions & 0 deletions examples/llama_nemo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
### NeMo Megatron setup:

- Install NeMo version: v1.17.0

```bash
git clone https://github.com/NVIDIA/NeMo/
cd NeMo
git checkout d3017e4
pip install -e '.[all]'
```

- Install Apex:
```bash
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```

### Convert LLaMa to NeMo:
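Example (run from `examples/llama_nemo`, so that the output lands in `nemo_llama2_7b/`, which is the path the training script loads):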

```bash
python convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b
```

### Training:
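Example run: [wandb](https://wandb.ai/carperai/trlxnemo/runs/v7592y73?workspace=user-pvduy). Submit the job from the repository root, since the script changes into `examples/llama_nemo` itself: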

```bash
sbatch dist_train.sh
```
144 changes: 144 additions & 0 deletions examples/llama_nemo/convert_llama_to_nemo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# flake8: noqa

import os
from pathlib import Path

import torch
from omegaconf.omegaconf import OmegaConf
from transformers import AutoModelForCausalLM


def main(args):  # noqa: C901
    """Convert a Hugging Face LLaMA checkpoint into tensor-parallel NeMo shards.

    Loads the causal LM at ``args.model_path`` in bfloat16, renames its weights
    to NeMo Megatron keys, slices them into ``args.total_tp`` tensor-parallel
    partitions, and writes one ``model_weights.ckpt`` per partition below
    ``args.output_folder`` together with ``megatron_{args.name}.yaml``.

    Args:
        args: Parsed CLI namespace providing ``model_path``, ``output_folder``,
            ``total_tp`` and ``name``.

    Raises:
        ValueError: If ``total_tp`` is not positive, if a sharded dimension is
            not divisible by ``total_tp``, or if the checkpoint uses
            grouped-query attention (which the fused QKV slicing below cannot
            represent).
    """
    TOTAL_TP = args.total_tp
    # Fail before the (slow) model load when the partition count is unusable.
    if TOTAL_TP < 1:
        raise ValueError(f"--total_tp must be a positive integer, got {TOTAL_TP}")

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    print("Loaded model")

    model_state_dict = model.state_dict()
    print("Model loaded")

    # Constants
    TOTAL_LAYERS = model.config.num_hidden_layers
    HIDDEN_DIM = model.config.hidden_size
    FFN_HIDDEN_DIM = model.config.intermediate_size
    VOCAB_SIZE = model.config.vocab_size
    NUM_HEADS = model.config.num_attention_heads
    OUTPUT_FOLDER = args.output_folder  # NeMo converted checkpoint folder with llama weights

    # Q, K and V are all sliced with the same PART_ATTN_DIM below, which is only
    # valid when the K/V projections have as many heads as the Q projection.
    num_kv_heads = getattr(model.config, "num_key_value_heads", None)
    if num_kv_heads is not None and num_kv_heads != NUM_HEADS:
        raise ValueError(
            f"Grouped-query attention is not supported: num_key_value_heads={num_kv_heads} "
            f"differs from num_attention_heads={NUM_HEADS}"
        )

    # Floor division would otherwise silently drop the trailing rows/columns.
    sharded_dims = (
        ("hidden_size", HIDDEN_DIM),
        ("intermediate_size", FFN_HIDDEN_DIM),
        ("vocab_size", VOCAB_SIZE),
        ("num_attention_heads", NUM_HEADS),
    )
    for dim_name, dim_size in sharded_dims:
        if dim_size % TOTAL_TP != 0:
            raise ValueError(
                f"{dim_name}={dim_size} cannot be split evenly across --total_tp={TOTAL_TP} ranks"
            )

    PART_ATTN_DIM = HIDDEN_DIM // TOTAL_TP
    PART_MLP_DIM = FFN_HIDDEN_DIM // TOTAL_TP
    EMBEDDING_DIM = VOCAB_SIZE // TOTAL_TP

    def build_layer_mapping(layer):
        """Map the NeMo keys of transformer layer ``layer`` to Hugging Face key(s)."""
        nemo_prefix = f"model.language_model.encoder.layers.{layer}"
        hf_prefix = f"model.layers.{layer}"
        return {
            f"{nemo_prefix}.input_layernorm.weight": f"{hf_prefix}.input_layernorm.weight",
            f"{nemo_prefix}.self_attention.query_key_value.weight": [
                f"{hf_prefix}.self_attn.q_proj.weight",
                f"{hf_prefix}.self_attn.k_proj.weight",
                f"{hf_prefix}.self_attn.v_proj.weight",
            ],
            f"{nemo_prefix}.self_attention.dense.weight": f"{hf_prefix}.self_attn.o_proj.weight",
            f"{nemo_prefix}.post_attention_layernorm.weight": f"{hf_prefix}.post_attention_layernorm.weight",
            f"{nemo_prefix}.mlp.dense_h_to_4h.weight": f"{hf_prefix}.mlp.gate_proj.weight",
            f"{nemo_prefix}.mlp.dense_h_to_4h_2.weight": f"{hf_prefix}.mlp.up_proj.weight",
            f"{nemo_prefix}.mlp.dense_4h_to_h.weight": f"{hf_prefix}.mlp.down_proj.weight",
        }

    def save_nemo_state_dict(nemo_state_dict, tp_idx):
        """Write one partition; multi-partition runs get a per-rank subfolder."""
        if TOTAL_TP == 1:
            target_dir = OUTPUT_FOLDER
        else:
            # Two-digit, zero-padded rank. The previous `mp_rank_0{tp_idx}` form
            # produced three digits (e.g. `mp_rank_010`) once tp_idx reached 10.
            target_dir = f"{OUTPUT_FOLDER}/mp_rank_{tp_idx:02d}"
        os.makedirs(target_dir, exist_ok=True)
        torch.save(nemo_state_dict, f"{target_dir}/model_weights.ckpt")

    def get_self_attention_weight(model_state_dict, layer_mapping, key, tp_idx):
        """Concatenate this rank's row slices of Q, K and V along dim 0."""
        rows = slice(tp_idx * PART_ATTN_DIM, (tp_idx + 1) * PART_ATTN_DIM)
        q_name, k_name, v_name = layer_mapping[key]
        return torch.cat(
            [model_state_dict[q_name][rows, :], model_state_dict[k_name][rows, :], model_state_dict[v_name][rows, :]],
            dim=0,
        )

    def get_mlp_weight(model_state_dict, layer_mapping, key, tp_idx):
        """Return this rank's row slice of an MLP input projection."""
        llama_weight = model_state_dict[layer_mapping[key]]
        return llama_weight[tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM, :]

    def map_weights(tp_idx):
        """Build and save the NeMo state dict for tensor-parallel rank ``tp_idx``."""
        nemo_state_dict = {}
        vocab_rows = slice(tp_idx * EMBEDDING_DIM, (tp_idx + 1) * EMBEDDING_DIM)
        attn_cols = slice(tp_idx * PART_ATTN_DIM, (tp_idx + 1) * PART_ATTN_DIM)
        mlp_cols = slice(tp_idx * PART_MLP_DIM, (tp_idx + 1) * PART_MLP_DIM)

        # Word embeddings and output head are split along the vocabulary dimension
        nemo_state_dict["model.language_model.embedding.word_embeddings.weight"] = model_state_dict[
            "model.embed_tokens.weight"
        ][vocab_rows, :]

        # The final layer norm is copied unsliced to every rank
        nemo_state_dict["model.language_model.encoder.final_layernorm.weight"] = model_state_dict["model.norm.weight"]

        nemo_state_dict["model.language_model.output_layer.weight"] = model_state_dict["lm_head.weight"][
            vocab_rows, :
        ]

        # Other layer mappings
        for layer in range(TOTAL_LAYERS):
            layer_mapping = build_layer_mapping(layer)
            for k, hf_key in layer_mapping.items():
                if "self_attention.query_key_value.weight" in k:
                    nemo_state_dict[k] = get_self_attention_weight(model_state_dict, layer_mapping, k, tp_idx)
                elif "self_attention.dense.weight" in k:
                    # Output projections are split along their input (column) dimension
                    nemo_state_dict[k] = model_state_dict[hf_key][:, attn_cols]
                elif "mlp.dense_h_to_4h.weight" in k or "mlp.dense_h_to_4h_2.weight" in k:
                    nemo_state_dict[k] = get_mlp_weight(model_state_dict, layer_mapping, k, tp_idx)
                elif "mlp.dense_4h_to_h.weight" in k:
                    nemo_state_dict[k] = model_state_dict[hf_key][:, mlp_cols]
                else:
                    # Layer norms: copied unsliced
                    nemo_state_dict[k] = model_state_dict[hf_key]

        # break view relationships otherwise pytorch will save original weights
        # to back the slices
        nemo_state_dict = {k: v.clone() for k, v in nemo_state_dict.items()}
        save_nemo_state_dict(nemo_state_dict, tp_idx)

    # Template config shipped next to this script; model-specific fields are
    # overwritten from the Hugging Face config below.
    megatron_cfg_path = Path(__file__).parent / "megatron_7b_llama.yaml"

    megatron_cfg = OmegaConf.load(megatron_cfg_path)
    megatron_cfg.name = f"megatron_{args.name}"
    megatron_cfg.trainer.num_nodes = 1
    megatron_cfg.trainer.devices = TOTAL_TP
    megatron_cfg.model.tensor_model_parallel_size = TOTAL_TP
    megatron_cfg.model.padded_vocab_size = model.config.vocab_size
    megatron_cfg.model.hidden_size = model.config.hidden_size
    megatron_cfg.model.ffn_hidden_size = model.config.intermediate_size
    megatron_cfg.model.num_layers = model.config.num_hidden_layers
    megatron_cfg.model.num_attention_heads = model.config.num_attention_heads
    megatron_cfg.model.max_position_embeddings = model.config.max_position_embeddings
    megatron_cfg.model.seq_length = model.config.max_position_embeddings

    megatron_cfg.exp_manager.create_wandb_logger = False
    megatron_cfg.exp_manager.create_checkpoint_callback = False

    print("Mapping weights")
    for tp in range(TOTAL_TP):
        map_weights(tp)

    OmegaConf.save(megatron_cfg, str(Path(OUTPUT_FOLDER) / f"megatron_{args.name}.yaml"))

    print("Done")


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Every option is mandatory; only the tensor-parallel degree is numeric.
    parser = ArgumentParser()
    for flag, flag_type in (
        ("--model_path", str),
        ("--output_folder", str),
        ("--total_tp", int),
        ("--name", str),
    ):
        parser.add_argument(flag, type=flag_type, required=True)
    main(parser.parse_args())
13 changes: 13 additions & 0 deletions examples/llama_nemo/dist_train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
#SBATCH --job-name=llama
#SBATCH --partition=g40
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --mem=0
#SBATCH --cpus-per-task=8
#SBATCH --output=out.txt
#SBATCH --error=error.txt
#SBATCH --exclusive

# Launch the LLaMA-2 PPO sentiment example with one task per srun slot.
# The training script loads its checkpoint via paths relative to
# examples/llama_nemo, so abort instead of launching from the wrong directory
# if the cd fails (e.g. the job was not submitted from the repository root).
cd examples/llama_nemo || exit 1
srun --label python nemo_llama2_ppo_sentiments.py
115 changes: 115 additions & 0 deletions examples/llama_nemo/nemo_llama2_ppo_sentiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

from datasets import load_dataset
from transformers import DistilBertForSequenceClassification, pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    """Return the score paired with the "POSITIVE" label in a pipeline result.

    Args:
        scores: Iterable of mappings, each holding exactly two values in the
            order (label, score).

    Returns:
        The score stored alongside the "POSITIVE" label.

    Raises:
        KeyError: If no entry carries the "POSITIVE" label.
    """
    label_to_score = {}
    for entry in scores:
        label, score = entry.values()
        label_to_score[label] = score
    return label_to_score["POSITIVE"]


def load_nemo_config():
    """Load the NeMo model/trainer config from ``nemo_llama2_7b/megatron_7b.yaml``.

    The path is relative to the current working directory.

    Returns:
        The OmegaConf object parsed from the YAML file.
    """
    # Deferred import so that merely importing this module does not pull in
    # the NeMo-side dependency.
    import omegaconf

    config_path = "nemo_llama2_7b/megatron_7b.yaml"
    return omegaconf.OmegaConf.load(config_path)


def main(hparams=None):
    """Run PPO sentiment tuning of the converted LLaMA-2 NeMo checkpoint.

    Builds the TRLX config on top of the default PPO config, sets up a
    DistilBERT IMDB sentiment classifier as the reward function, and starts
    training on prompts taken from the first words of IMDB reviews.

    Args:
        hparams: Optional dict of overrides merged into the default PPO
            config. ``None`` (the default) means no overrides.

    Raises:
        KeyError: If the ``SLURM_PROCID`` environment variable is not set.
    """
    # Use None rather than a shared mutable `{}` as the default argument.
    if hparams is None:
        hparams = {}

    # Merge sweep config with default config if given
    default_config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
    nemo_config = load_nemo_config()
    print(nemo_config)
    cfg_name = "llama2-7b"
    config = default_config.evolve(
        train=dict(
            total_steps=1600,
            seq_length=256,
            batch_size=16,
            epochs=100,
            eval_interval=100,
            trainer="NeMoPPOTrainer",
            trainer_kwargs=dict(
                pretrained_model="nemo_llama2_7b/",
                megatron_cfg=nemo_config,
            ),
            checkpoint_interval=256,
            checkpoint_dir=f"nemo_{cfg_name}_ppo_sentiments",
            seed=2023,
            project_name="trlxnemo",
            tags=["nemo", "ppo", "sentiments", cfg_name],
        ),
        optimizer=dict(
            name="adamw",
            kwargs=dict(
                lr=1e-5,
                weight_decay=1e-06,
                eps=1.0e-8,
                betas=(0.9, 0.95),
            ),
        ),
        scheduler=dict(
            name="CosineAnnealing",
        ),
        model=dict(num_layers_unfrozen=24),
        method=dict(
            num_rollouts=128,
            init_kl_coef=0.05,
            vf_coef=1,
            scale_reward="ignored",
            gamma=1,
            lam=0.95,
            cliprange=0.2,
            cliprange_value=0.2,
            gen_kwargs=dict(temperature=1.0, max_new_tokens=64),
            chunk_size=64,
            ppo_epochs=4,
        ),
    )
    # Replace the scheduler kwargs wholesale after the merge.
    config.scheduler.kwargs = dict(warmup_steps=0, constant_steps=1e12, min_lr=1e-6)

    rank = int(os.environ["SLURM_PROCID"])
    # Assumes 8 tasks per node, matching `--ntasks-per-node=8` in dist_train.sh.
    local_rank = rank % 8

    reward_model = DistilBertForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb")
    reward_model.to("cpu")
    sentiment_fn = pipeline(
        "sentiment-analysis",
        model=reward_model,  # "lvwerra/distilbert-imdb",
        tokenizer="lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=local_rank,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        """Score each sample with the probability of positive sentiment."""
        # The reward model sits on this rank's device only while scoring and is
        # moved back to CPU afterwards (presumably to save GPU memory).
        reward_model.to(local_rank)
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        reward_model.to("cpu")
        return sentiments

    # Take few words off of movies reviews as prompts
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    # An optional first CLI argument carries hyperparameter overrides as JSON.
    if len(sys.argv) == 1:
        hparams = {}
    else:
        hparams = json.loads(sys.argv[1])
    main(hparams)
30 changes: 28 additions & 2 deletions trlx/models/modeling_nemo_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
MegatronGPTModel,
)
from nemo.collections.nlp.modules.common.megatron.attention import ParallelAttention
from nemo.collections.nlp.modules.common.megatron.module import (
Float16Module,
MegatronModule,
Expand Down Expand Up @@ -58,6 +59,11 @@
_PER_DP_RANK_RNG = "per-data-parallel-rank-rng"


def patch_attention_for_llama(m):
    """Set ``megatron_legacy = True`` on ``m`` if it is a ``ParallelAttention``.

    Modules of any other type are left untouched, so the function can be
    passed directly to ``module.apply(...)``.

    Args:
        m: The module to inspect.
    """
    if not isinstance(m, ParallelAttention):
        return
    m.megatron_legacy = True


class ParallelLinear(nn.Module):
"""Linear layer parallelized over the longer dimension."""

Expand Down Expand Up @@ -181,7 +187,13 @@ def __init__(self, language_model, other_heads, build_reference_model=True):
self.reference_model_offloaded = True

self.other_heads = other_heads
self.word_embeddings = language_model.word_embeddings
if hasattr(language_model, "output_layer"):
self.output_layer = self._lm.language_model.output_layer
self.word_embeddings = self.output_layer.weight
else:
if hasattr(language_model, "word_embeddings"):
self.word_embeddings = language_model.word_embeddings
self.output_layer = None

# The tensor from the previous pipeline rank arrives via this method
def set_input_tensor(self, input_tensor):
Expand All @@ -194,6 +206,15 @@ def load_state_dict(self, lm_state_dict, strict=True):
"""Load GPTModel state dict."""
self.language_model.load_state_dict(lm_state_dict, strict=strict)

if "output_layer.weight" in lm_state_dict:
dtype = lm_state_dict["output_layer.weight"].dtype
device = self.language_model.output_layer.weight.device
params = torch.nn.Parameter(
lm_state_dict["output_layer.weight"].to(device, dtype=dtype), requires_grad=True
)
self.language_model.output_layer.weight = params
print("Loaded output_layer.weight from lm_state_dict")

if self.build_reference_model:
for p in self.language_model.parameters():
if p.requires_grad:
Expand Down Expand Up @@ -232,7 +253,10 @@ def forward(
run_value_head=False,
**kwargs,
):
logit_weights = self._lm.word_embeddings_weight()
if hasattr(self._lm.language_model, "output_layer"):
logit_weights = self._lm.language_model.output_layer.weight
else:
logit_weights = self._lm.word_embeddings_weight()

if run_policy_model:
self.offload_reference_model()
Expand Down Expand Up @@ -501,6 +525,8 @@ def freeze_layers(m):
p.requires_grad_(False)
gpt.language_model.apply(freeze_layers)

if self.cfg.get("megatron_legacy", False):
gpt.apply(patch_attention_for_llama)
# If running on the last pipeline stage, add the PPO value head and hydra reference model
if post_process:
value_head = ValueHead(self.cfg.hidden_size, self.cfg.sequence_parallel)
Expand Down
8 changes: 8 additions & 0 deletions trlx/models/modeling_nemo_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ def trim_key(key, prefix):
unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True)
print(f"Loaded from pretrained {rank_params}")

def model_provider_func(self, *args, **kwargs):
    """Build the GPT model and optionally apply the LLaMA attention patch.

    All arguments are forwarded unchanged to the parent implementation.

    Returns:
        The model returned by the parent ``model_provider_func``. When the
        config sets ``megatron_legacy``, its ``ParallelAttention`` modules are
        flagged as legacy first.
    """
    gpt = super().model_provider_func(*args, **kwargs)

    # Gate the patch on the config flag, as the PPO model provider does, so
    # that models not configured with `megatron_legacy` keep their default
    # attention behaviour.
    if self.cfg.get("megatron_legacy", False):
        # Imported locally, as before (presumably to avoid a circular import
        # between the SFT and PPO modules; TODO confirm).
        from trlx.models.modeling_nemo_ppo import patch_attention_for_llama

        gpt.apply(patch_attention_for_llama)
    return gpt

# Adapted from NeMo
# https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259
def training_step(self, batch: List[torch.Tensor], batch_idx: int): # noqa: C901
Expand Down

0 comments on commit ffa5ba1

Please sign in to comment.