* init
* fix logging bug
* sync
* optimize hparams and code
* tweak hp
* sync
* fix loading issue
* fix annoying off by one error
* ref logps
* cleanup
* remove unused
* fix bug when no config specified
* parametric over llama size
* strict=True
* add readme and llama2 example
* hparams 7b
* fix style
* Update README.md
* remove bad file and 7b llama2 config
* update wandb

---------

Co-authored-by: cat-state <cat@meow>
Co-authored-by: Duy V. Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
1 parent 3685a75 · commit ffa5ba1
Showing 6 changed files with 340 additions and 2 deletions.
@@ -0,0 +1,32 @@
### NeMo Megatron setup:

- Install NeMo version: v1.17.0

```bash
git clone https://github.com/NVIDIA/NeMo/
cd NeMo
git checkout d3017e4
pip install -e '.[all]'
```
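
To confirm the pinned checkout reports the expected release, a quick optional check (not part of the original README; it assumes `nemo.__version__` is populated at this revision):

```bash
# Should print a version close to 1.17.0 for the d3017e4 checkout
python -c "import nemo; print(nemo.__version__)"
```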

- Install Apex:
```bash
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```
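
An optional sanity check that the CUDA extensions actually built (a hedged sketch, not from the original README; the extension module names `amp_C` and `fused_layer_norm_cuda` are assumptions about this Apex revision):

```bash
# These modules only exist when Apex was built with --cpp_ext/--cuda_ext
python -c "import amp_C, fused_layer_norm_cuda; print('Apex CUDA extensions OK')"
```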

### Convert LLaMa to NeMo:
Example:

```bash
python convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b
```
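
With `--total_tp 4`, the conversion script writes one weight shard per tensor-parallel rank plus a generated NeMo config. The layout sketched here is inferred from the script and may differ in detail:

```bash
ls -R nemo_llama2_7b
# nemo_llama2_7b/
# ├── megatron_7b.yaml               # generated trainer/model config
# ├── mp_rank_00/model_weights.ckpt
# ├── mp_rank_01/model_weights.ckpt
# ├── mp_rank_02/model_weights.ckpt
# └── mp_rank_03/model_weights.ckpt
```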

### Training:
Example: [wandb](https://wandb.ai/carperai/trlxnemo/runs/v7592y73?workspace=user-pvduy)

```bash
sbatch dist_train.sh
```
@@ -0,0 +1,144 @@
# flake8: noqa

import os
from pathlib import Path

import torch
from omegaconf.omegaconf import OmegaConf
from transformers import AutoModelForCausalLM


def main(args):  # noqa: C901
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    print("Loaded model")

    model_state_dict = model.state_dict()
    print("Model loaded")

    # Constants
    TOTAL_LAYERS = model.config.num_hidden_layers
    TOTAL_TP = args.total_tp
    HIDDEN_DIM = model.config.hidden_size
    FFN_HIDDEN_DIM = model.config.intermediate_size
    PART_ATTN_DIM = HIDDEN_DIM // TOTAL_TP  # attention rows/columns per tensor-parallel rank
    PART_MLP_DIM = FFN_HIDDEN_DIM // TOTAL_TP  # MLP rows/columns per tensor-parallel rank
    VOCAB_SIZE = model.config.vocab_size
    EMBEDDING_DIM = VOCAB_SIZE // TOTAL_TP  # embedding/output rows per tensor-parallel rank
    OUTPUT_FOLDER = args.output_folder  # NeMo converted checkpoint folder with llama weights

    # Weight mapping helpers

    def build_layer_mapping(layer):
        # Maps NeMo parameter names to the corresponding Hugging Face LLaMA names
        return {
            f"model.language_model.encoder.layers.{layer}.input_layernorm.weight": f"model.layers.{layer}.input_layernorm.weight",
            f"model.language_model.encoder.layers.{layer}.self_attention.query_key_value.weight": [
                f"model.layers.{layer}.self_attn.q_proj.weight",
                f"model.layers.{layer}.self_attn.k_proj.weight",
                f"model.layers.{layer}.self_attn.v_proj.weight",
            ],
            f"model.language_model.encoder.layers.{layer}.self_attention.dense.weight": f"model.layers.{layer}.self_attn.o_proj.weight",
            f"model.language_model.encoder.layers.{layer}.post_attention_layernorm.weight": f"model.layers.{layer}.post_attention_layernorm.weight",
            f"model.language_model.encoder.layers.{layer}.mlp.dense_h_to_4h.weight": f"model.layers.{layer}.mlp.gate_proj.weight",
            f"model.language_model.encoder.layers.{layer}.mlp.dense_h_to_4h_2.weight": f"model.layers.{layer}.mlp.up_proj.weight",
            f"model.language_model.encoder.layers.{layer}.mlp.dense_4h_to_h.weight": f"model.layers.{layer}.mlp.down_proj.weight",
        }

    def save_nemo_state_dict(nemo_state_dict, tp_idx):
        if TOTAL_TP == 1:
            os.makedirs(OUTPUT_FOLDER, exist_ok=True)
            torch.save(nemo_state_dict, f"{OUTPUT_FOLDER}/model_weights.ckpt")
        else:
            os.makedirs(f"{OUTPUT_FOLDER}/mp_rank_0{tp_idx}", exist_ok=True)
            torch.save(nemo_state_dict, f"{OUTPUT_FOLDER}/mp_rank_0{tp_idx}/model_weights.ckpt")

    def map_weights(tp_idx):
        nemo_state_dict = {}

        # Word embeddings mapping

        nemo_state_dict["model.language_model.embedding.word_embeddings.weight"] = model_state_dict[
            "model.embed_tokens.weight"
        ][tp_idx * EMBEDDING_DIM : (tp_idx + 1) * EMBEDDING_DIM, :]

        nemo_state_dict["model.language_model.encoder.final_layernorm.weight"] = model_state_dict["model.norm.weight"]

        nemo_state_dict["model.language_model.output_layer.weight"] = model_state_dict["lm_head.weight"][
            tp_idx * EMBEDDING_DIM : (tp_idx + 1) * EMBEDDING_DIM, :
        ]

        # Other layer mappings
        for layer in range(TOTAL_LAYERS):
            layer_mapping = build_layer_mapping(layer)
            for k in layer_mapping.keys():
                # original_size = nemo_state_dict[k].shape
                if "self_attention.query_key_value.weight" in k:
                    nemo_state_dict[k] = get_self_attention_weight(model_state_dict, layer_mapping, k, tp_idx)
                elif "self_attention.dense.weight" in k:
                    nemo_state_dict[k] = model_state_dict[layer_mapping[k]][
                        :, tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM
                    ]
                elif "mlp.dense_h_to_4h.weight" in k or "mlp.dense_h_to_4h_2.weight" in k:
                    nemo_state_dict[k] = get_mlp_weight(model_state_dict, layer_mapping, k, tp_idx)
                elif "mlp.dense_4h_to_h.weight" in k:
                    nemo_state_dict[k] = model_state_dict[layer_mapping[k]][
                        :, tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM
                    ]
                elif isinstance(layer_mapping[k], torch.Tensor):
                    nemo_state_dict[k] = layer_mapping[k]
                else:
                    nemo_state_dict[k] = model_state_dict[layer_mapping[k]]

        # break view relationships otherwise pytorch will save original weights
        # to back the slices
        nemo_state_dict = {k: v.clone() for k, v in nemo_state_dict.items()}
        save_nemo_state_dict(nemo_state_dict, tp_idx)

    def get_self_attention_weight(model_state_dict, layer_mapping, key, tp_idx):
        # Slice this rank's rows from q, k and v, then fuse them into one QKV matrix
        llama_query = model_state_dict[layer_mapping[key][0]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :]
        llama_key = model_state_dict[layer_mapping[key][1]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :]
        llama_value = model_state_dict[layer_mapping[key][2]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :]
        return torch.cat([llama_query, llama_key, llama_value], dim=0)

    def get_mlp_weight(model_state_dict, layer_mapping, key, tp_idx):
        llama_weight = model_state_dict[layer_mapping[key]]
        return llama_weight[tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM, :]

    # dummy config

    megatron_cfg_path = Path(__file__).parent / "megatron_7b_llama.yaml"

    megatron_cfg = OmegaConf.load(megatron_cfg_path)
    megatron_cfg.name = f"megatron_{args.name}"
    megatron_cfg.trainer.num_nodes = 1
    megatron_cfg.trainer.devices = TOTAL_TP
    megatron_cfg.model.tensor_model_parallel_size = TOTAL_TP
    megatron_cfg.model.padded_vocab_size = model.config.vocab_size
    megatron_cfg.model.hidden_size = model.config.hidden_size
    megatron_cfg.model.ffn_hidden_size = model.config.intermediate_size
    megatron_cfg.model.num_layers = model.config.num_hidden_layers
    megatron_cfg.model.num_attention_heads = model.config.num_attention_heads
    megatron_cfg.model.max_position_embeddings = model.config.max_position_embeddings
    megatron_cfg.model.seq_length = model.config.max_position_embeddings

    megatron_cfg.exp_manager.create_wandb_logger = False
    megatron_cfg.exp_manager.create_checkpoint_callback = False

    print("Mapping weights")
    for tp in range(TOTAL_TP):
        map_weights(tp)

    OmegaConf.save(megatron_cfg, str(Path(OUTPUT_FOLDER) / f"megatron_{args.name}.yaml"))

    print("Done")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--total_tp", type=int, required=True)
    parser.add_argument("--name", type=str, required=True)
    main(parser.parse_args())
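
A minimal sanity check for the converted shards (an illustration, not part of this commit): load one tensor-parallel shard and compare a couple of tensor shapes against what the slicing above implies. Paths and sizes assume the README example (Llama-2-7b, `--total_tp 4`, hidden size 4096, vocab size 32000).

```python
import torch

# One shard written by convert_llama_to_nemo.py (rank 0 of 4)
shard = torch.load("nemo_llama2_7b/mp_rank_00/model_weights.ckpt", map_location="cpu")

# Fused QKV for layer 0: this rank's q, k and v slices concatenated along dim 0
qkv = shard["model.language_model.encoder.layers.0.self_attention.query_key_value.weight"]
print(qkv.shape)  # expected (3 * 4096 // 4, 4096) == (3072, 4096)

# Word embeddings are split across ranks along the vocab dimension
emb = shard["model.language_model.embedding.word_embeddings.weight"]
print(emb.shape)  # expected (32000 // 4, 4096) == (8000, 4096)
```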
@@ -0,0 +1,13 @@
#!/bin/bash
#SBATCH --job-name=llama
#SBATCH --partition=g40
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --mem=0
#SBATCH --cpus-per-task=8
#SBATCH --output=out.txt
#SBATCH --error=error.txt
#SBATCH --exclusive

cd examples/llama_nemo
srun --label python nemo_llama2_ppo_sentiments.py
@@ -0,0 +1,115 @@
# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

from datasets import load_dataset
from transformers import DistilBertForSequenceClassification, pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def load_nemo_config():
    """Load the converted LLaMA 2 7B NeMo Megatron model and trainer config"""
    # Import here to not require nemo as a dependency
    from omegaconf import OmegaConf

    return OmegaConf.load("nemo_llama2_7b/megatron_7b.yaml")


def main(hparams={}):
    # Merge sweep config with default config if given
    default_config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
    nemo_config = load_nemo_config()
    print(nemo_config)
    cfg_name = "llama2-7b"
    config = default_config.evolve(
        train=dict(
            total_steps=1600,
            seq_length=256,
            batch_size=16,
            epochs=100,
            eval_interval=100,
            trainer="NeMoPPOTrainer",
            trainer_kwargs=dict(
                pretrained_model="nemo_llama2_7b/",
                megatron_cfg=nemo_config,
            ),
            checkpoint_interval=256,
            checkpoint_dir=f"nemo_{cfg_name}_ppo_sentiments",
            seed=2023,
            project_name="trlxnemo",
            tags=["nemo", "ppo", "sentiments", cfg_name],
        ),
        optimizer=dict(
            name="adamw",
            kwargs=dict(
                lr=1e-5,
                weight_decay=1e-06,
                eps=1.0e-8,
                betas=(0.9, 0.95),
            ),
        ),
        scheduler=dict(
            name="CosineAnnealing",
        ),
        model=dict(num_layers_unfrozen=24),
        method=dict(
            num_rollouts=128,
            init_kl_coef=0.05,
            vf_coef=1,
            scale_reward="ignored",
            gamma=1,
            lam=0.95,
            cliprange=0.2,
            cliprange_value=0.2,
            gen_kwargs=dict(temperature=1.0, max_new_tokens=64),
            chunk_size=64,
            ppo_epochs=4,
        ),
    )
    config.scheduler.kwargs = dict(warmup_steps=0, constant_steps=1e12, min_lr=1e-6)

    rank = int(os.environ["SLURM_PROCID"])
    local_rank = rank % 8

    reward_model = DistilBertForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb")
    reward_model.to("cpu")
    sentiment_fn = pipeline(
        "sentiment-analysis",
        model=reward_model,  # "lvwerra/distilbert-imdb",
        tokenizer="lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=local_rank,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        # Move the reward model onto this rank's GPU only while scoring, then back to CPU
        reward_model.to(local_rank)
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        reward_model.to("cpu")
        return sentiments

    # Take a few words off of movie reviews as prompts
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)