diff --git a/examples/llama_nemo/README.md b/examples/llama_nemo/README.md
new file mode 100644
index 000000000..a09732aeb
--- /dev/null
+++ b/examples/llama_nemo/README.md
@@ -0,0 +1,47 @@
+### NeMo Megatron setup:
+
+- Install NeMo v1.17.0 (commit `d3017e4`):
+
+```bash
+git clone https://github.com/NVIDIA/NeMo/
+cd NeMo
+git checkout d3017e4
+pip install -e '.[all]'
+```
+
+- Install Apex:
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+# requires pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1), which supports multiple `--config-settings` with the same key
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+```
+
+### Convert LLaMA to NeMo:
+Example:
+
+```bash
+python convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b
+```
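+
+To sanity-check the conversion, load one tensor-parallel shard and inspect a couple of shapes (a minimal sketch, assuming the output layout produced by the command above; adjust the paths for your `--output_folder` and `--total_tp`):
+
+```python
+import torch
+
+# Each tensor-parallel rank gets its own shard: mp_rank_00 ... mp_rank_03 for --total_tp 4.
+shard = torch.load("nemo_llama2_7b/mp_rank_00/model_weights.ckpt", map_location="cpu")
+
+# The vocabulary is split across ranks: (vocab_size // total_tp, hidden_size).
+print(shard["model.language_model.embedding.word_embeddings.weight"].shape)
+
+# Fused QKV weight for layer 0: (3 * hidden_size // total_tp, hidden_size).
+print(shard["model.language_model.encoder.layers.0.self_attention.query_key_value.weight"].shape)
+```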
f"{OUTPUT_FOLDER}/mp_rank_0{tp_idx}/model_weights.ckpt") + + def map_weights(tp_idx): + nemo_state_dict = {} + + # Word embeddings mapping + + nemo_state_dict["model.language_model.embedding.word_embeddings.weight"] = model_state_dict[ + "model.embed_tokens.weight" + ][tp_idx * EMBEDDING_DIM : (tp_idx + 1) * EMBEDDING_DIM, :] + + nemo_state_dict["model.language_model.encoder.final_layernorm.weight"] = model_state_dict["model.norm.weight"] + + nemo_state_dict["model.language_model.output_layer.weight"] = model_state_dict["lm_head.weight"][ + tp_idx * EMBEDDING_DIM : (tp_idx + 1) * EMBEDDING_DIM, : + ] + + # Other layer mappings + for layer in range(TOTAL_LAYERS): + layer_mapping = build_layer_mapping(layer) + for k in layer_mapping.keys(): + # original_size = nemo_state_dict[k].shape + if "self_attention.query_key_value.weight" in k: + nemo_state_dict[k] = get_self_attention_weight(model_state_dict, layer_mapping, k, tp_idx) + elif "self_attention.dense.weight" in k: + nemo_state_dict[k] = model_state_dict[layer_mapping[k]][ + :, tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM + ] + elif "mlp.dense_h_to_4h.weight" in k or "mlp.dense_h_to_4h_2.weight" in k: + nemo_state_dict[k] = get_mlp_weight(model_state_dict, layer_mapping, k, tp_idx) + elif "mlp.dense_4h_to_h.weight" in k: + nemo_state_dict[k] = model_state_dict[layer_mapping[k]][ + :, tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM + ] + elif isinstance(layer_mapping[k], torch.Tensor): + nemo_state_dict[k] = layer_mapping[k] + else: + nemo_state_dict[k] = model_state_dict[layer_mapping[k]] + + # break view relationships otherwise pytorch will save original weights + # to back the slices + nemo_state_dict = {k: v.clone() for k, v in nemo_state_dict.items()} + save_nemo_state_dict(nemo_state_dict, tp_idx) + + def get_self_attention_weight(model_state_dict, layer_mapping, key, tp_idx): + llama_query = model_state_dict[layer_mapping[key][0]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :] + llama_key = model_state_dict[layer_mapping[key][1]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :] + llama_value = model_state_dict[layer_mapping[key][2]][tp_idx * PART_ATTN_DIM : (tp_idx + 1) * PART_ATTN_DIM, :] + return torch.cat([llama_query, llama_key, llama_value], dim=0) + + def get_mlp_weight(model_state_dict, layer_mapping, key, tp_idx): + llama_weight = model_state_dict[layer_mapping[key]] + return llama_weight[tp_idx * PART_MLP_DIM : (tp_idx + 1) * PART_MLP_DIM, :] + + # dummy config + + megatron_cfg_path = Path(__file__).parent / "megatron_7b_llama.yaml" + + megatron_cfg = OmegaConf.load(megatron_cfg_path) + megatron_cfg.name = f"megatron_{args.name}" + megatron_cfg.trainer.num_nodes = 1 + megatron_cfg.trainer.devices = TOTAL_TP + megatron_cfg.model.tensor_model_parallel_size = TOTAL_TP + megatron_cfg.model.padded_vocab_size = model.config.vocab_size + megatron_cfg.model.hidden_size = model.config.hidden_size + megatron_cfg.model.ffn_hidden_size = model.config.intermediate_size + megatron_cfg.model.num_layers = model.config.num_hidden_layers + megatron_cfg.model.num_attention_heads = model.config.num_attention_heads + megatron_cfg.model.max_position_embeddings = model.config.max_position_embeddings + megatron_cfg.model.seq_length = model.config.max_position_embeddings + + megatron_cfg.exp_manager.create_wandb_logger = False + megatron_cfg.exp_manager.create_checkpoint_callback = False + + print("Mapping weights") + for tp in range(TOTAL_TP): + map_weights(tp) + + OmegaConf.save(megatron_cfg, 
str(Path(OUTPUT_FOLDER) / f"megatron_{args.name}.yaml")) + + print("Done") + + +if __name__ == "__main__": + from argparse import ArgumentParser + + args = ArgumentParser() + args.add_argument("--model_path", type=str, required=True) + args.add_argument("--output_folder", type=str, required=True) + args.add_argument("--total_tp", type=int, required=True) + args.add_argument("--name", type=str, required=True) + main(args.parse_args()) diff --git a/examples/llama_nemo/dist_train.sh b/examples/llama_nemo/dist_train.sh new file mode 100755 index 000000000..feb2fa887 --- /dev/null +++ b/examples/llama_nemo/dist_train.sh @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --job-name=llama +#SBATCH --partition=g40 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=8 +#SBATCH --mem=0 +#SBATCH --cpus-per-task=8 +#SBATCH --output=out.txt +#SBATCH --error=error.txt +#SBATCH --exclusive + +cd examples/llama_nemo +srun --label python nemo_llama2_ppo_sentiments.py diff --git a/examples/llama_nemo/nemo_llama2_ppo_sentiments.py b/examples/llama_nemo/nemo_llama2_ppo_sentiments.py new file mode 100644 index 000000000..efe9d56c2 --- /dev/null +++ b/examples/llama_nemo/nemo_llama2_ppo_sentiments.py @@ -0,0 +1,115 @@ +# Generates positive movie reviews by tuning a pretrained model on IMDB dataset +# with a sentiment reward function +import json +import os +import sys +from typing import List + +from datasets import load_dataset +from transformers import DistilBertForSequenceClassification, pipeline + +import trlx +from trlx.data.default_configs import TRLConfig, default_ppo_config + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + + +def load_nemo_config(): + """Load nemo-megatron-1.3b model and trainer config""" + # Import here to not require nemo as a dependency + from omegaconf import OmegaConf + + return OmegaConf.load("nemo_llama2_7b/megatron_7b.yaml") + + +def main(hparams={}): + # Merge sweep config with default config if given + default_config = TRLConfig.update(default_ppo_config().to_dict(), hparams) + nemo_config = load_nemo_config() + print(nemo_config) + cfg_name = "llama2-7b" + config = default_config.evolve( + train=dict( + total_steps=1600, + seq_length=256, + batch_size=16, + epochs=100, + eval_interval=100, + trainer="NeMoPPOTrainer", + trainer_kwargs=dict( + pretrained_model="nemo_llama2_7b/", + megatron_cfg=nemo_config, + ), + checkpoint_interval=256, + checkpoint_dir=f"nemo_{cfg_name}_ppo_sentiments", + seed=2023, + project_name="trlxnemo", + tags=["nemo", "ppo", "sentiments", cfg_name], + ), + optimizer=dict( + name="adamw", + kwargs=dict( + lr=1e-5, + weight_decay=1e-06, + eps=1.0e-8, + betas=(0.9, 0.95), + ), + ), + scheduler=dict( + name="CosineAnnealing", + ), + model=dict(num_layers_unfrozen=24), + method=dict( + num_rollouts=128, + init_kl_coef=0.05, + vf_coef=1, + scale_reward="ignored", + gamma=1, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + gen_kwargs=dict(temperature=1.0, max_new_tokens=64), + chunk_size=64, + ppo_epochs=4, + ), + ) + config.scheduler.kwargs = dict(warmup_steps=0, constant_steps=1e12, min_lr=1e-6) + + rank = int(os.environ["SLURM_PROCID"]) + local_rank = rank % 8 + + reward_model = DistilBertForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb") + reward_model.to("cpu") + sentiment_fn = pipeline( + "sentiment-analysis", + model=reward_model, # "lvwerra/distilbert-imdb", + tokenizer="lvwerra/distilbert-imdb", + top_k=2, + 
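+    # With top_k=2 the pipeline returns both class scores for every sample, e.g.
+    # [{"label": "POSITIVE", "score": 0.81}, {"label": "NEGATIVE", "score": 0.19}],
+    # and get_positive_score() picks out the "POSITIVE" entry as the scalar reward.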
+
+    def reward_fn(samples: List[str], **kwargs) -> List[float]:
+        reward_model.to(local_rank)
+        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
+        reward_model.to("cpu")
+        return sentiments
+
+    # Take the first few words of movie reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+    trlx.train(
+        reward_fn=reward_fn,
+        prompts=prompts,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
diff --git a/trlx/models/modeling_nemo_ppo.py b/trlx/models/modeling_nemo_ppo.py
index 17dee85c0..5ad044130 100644
--- a/trlx/models/modeling_nemo_ppo.py
+++ b/trlx/models/modeling_nemo_ppo.py
@@ -28,6 +28,7 @@
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
     MegatronGPTModel,
 )
+from nemo.collections.nlp.modules.common.megatron.attention import ParallelAttention
 from nemo.collections.nlp.modules.common.megatron.module import (
     Float16Module,
     MegatronModule,
@@ -58,6 +59,11 @@
 _PER_DP_RANK_RNG = "per-data-parallel-rank-rng"
 
 
+def patch_attention_for_llama(m):
+    if isinstance(m, ParallelAttention):
+        m.megatron_legacy = True
+
+
 class ParallelLinear(nn.Module):
     """Linear layer parallelized over the longer dimension."""
 
@@ -181,7 +187,13 @@ def __init__(self, language_model, other_heads, build_reference_model=True):
             self.reference_model_offloaded = True
 
         self.other_heads = other_heads
-        self.word_embeddings = language_model.word_embeddings
+        if hasattr(language_model, "output_layer"):
+            self.output_layer = self._lm.language_model.output_layer
+            self.word_embeddings = self.output_layer.weight
+        else:
+            if hasattr(language_model, "word_embeddings"):
+                self.word_embeddings = language_model.word_embeddings
+            self.output_layer = None
 
     # The tensor from the previous pipeline rank arrives via this method
     def set_input_tensor(self, input_tensor):
@@ -194,6 +206,15 @@ def load_state_dict(self, lm_state_dict, strict=True):
         """Load GPTModel state dict."""
         self.language_model.load_state_dict(lm_state_dict, strict=strict)
 
+        if "output_layer.weight" in lm_state_dict:
+            dtype = lm_state_dict["output_layer.weight"].dtype
+            device = self.language_model.output_layer.weight.device
+            params = torch.nn.Parameter(
+                lm_state_dict["output_layer.weight"].to(device, dtype=dtype), requires_grad=True
+            )
+            self.language_model.output_layer.weight = params
+            print("Loaded output_layer.weight from lm_state_dict")
+
         if self.build_reference_model:
             for p in self.language_model.parameters():
                 if p.requires_grad:
@@ -232,7 +253,10 @@ def forward(
         run_value_head=False,
         **kwargs,
     ):
-        logit_weights = self._lm.word_embeddings_weight()
+        if hasattr(self._lm.language_model, "output_layer"):
+            logit_weights = self._lm.language_model.output_layer.weight
+        else:
+            logit_weights = self._lm.word_embeddings_weight()
 
         if run_policy_model:
             self.offload_reference_model()
@@ -501,6 +525,8 @@ def freeze_layers(m):
                         p.requires_grad_(False)
 
             gpt.language_model.apply(freeze_layers)
+            if self.cfg.get("megatron_legacy", False):
+                gpt.apply(patch_attention_for_llama)
         # If running on the last pipeline stage, add the PPO value head and hydra reference model
         if post_process:
             value_head = ValueHead(self.cfg.hidden_size, self.cfg.sequence_parallel)
diff --git a/trlx/models/modeling_nemo_sft.py b/trlx/models/modeling_nemo_sft.py
index 10e2e28d5..a3fa70fe3 100644
--- a/trlx/models/modeling_nemo_sft.py
+++ b/trlx/models/modeling_nemo_sft.py
@@ -125,6 +125,14 @@ def trim_key(key, prefix):
         unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True)
         print(f"Loaded from pretrained {rank_params}")
 
+    def model_provider_func(self, *args, **kwargs):
+        gpt = super().model_provider_func(*args, **kwargs)
+
+        from trlx.models.modeling_nemo_ppo import patch_attention_for_llama
+
+        gpt.apply(patch_attention_for_llama)
+        return gpt
+
     # Adapted from NeMo
     # https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259
     def training_step(self, batch: List[torch.Tensor], batch_idx: int):  # noqa: C901
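
Note: on the PPO side, `patch_attention_for_llama` is only applied when the model config exposes a truthy `megatron_legacy` flag (see `self.cfg.get("megatron_legacy", False)` above), so the YAML shipped with the converted checkpoint is expected to carry it; the `megatron_7b_llama.yaml` template used by the conversion script is assumed to set it. A minimal sanity check, assuming the README's default paths:

```python
from omegaconf import OmegaConf

# Paths from the README example (--output_folder nemo_llama2_7b, --name 7b).
cfg = OmegaConf.load("nemo_llama2_7b/megatron_7b.yaml")

# The PPO model builder applies patch_attention_for_llama only when this flag is set.
assert cfg.model.get("megatron_legacy", False), "expected model.megatron_legacy: true for converted LLaMA weights"
```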