From 94924de5457ff77fdb51921c03c0edb952be07f6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 25 Jan 2024 06:36:00 -0500 Subject: [PATCH 01/25] integrate tinyllama dataloaders and model args --- examples/run_axonn_amd.sh | 79 ++++++++++---- examples/run_axonn_amd_tinyllama.sh | 157 ++++++++++++++++++++++++++++ megatron/arguments.py | 2 + megatron/model/transformer.py | 2 +- megatron/training.py | 43 +++++--- pretrain_gpt.py | 55 ++++++---- 6 files changed, 283 insertions(+), 55 deletions(-) create mode 100755 examples/run_axonn_amd_tinyllama.sh diff --git a/examples/run_axonn_amd.sh b/examples/run_axonn_amd.sh index ec3bf26923..2f1ccc7899 100755 --- a/examples/run_axonn_amd.sh +++ b/examples/run_axonn_amd.sh @@ -1,4 +1,7 @@ #!/bin/bash +#SBATCH -p batch +#SBATCH -A CSC547 +#SBATCH -t 00:20:00 # Runs the "345M" parameter model @@ -7,18 +10,20 @@ module load cray-python module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses ## these lines enable CUDA aware MPI -module load craype-accel-amd-gfx90a +#module load craype-accel-amd-gfx90a export MPICH_GPU_SUPPORT_ENABLED=0 -export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib" - +#export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib" ## this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/lustre/orion/scratch/ssingh37/csc547/aws-ofi-rccl/build/lib" #export NCCL_DEBUG=INFO export FI_CXI_ATS=0 - +export HSA_FORCE_FINE_GRAIN_PCIE=1 +#export NCCL_SOCKET_IFNAME=hsn +# super important +#export NCCL_NET_GDR_LEVEL=4 +#export NCCL_P2P_LEVEL=4 ## this improves cross node bandwidth for some cases export NCCL_CROSS_NIC=1 - export CUDA_DEVICE_MAX_CONNECTIONS=1 NNODES=$SLURM_JOB_NUM_NODES @@ -31,37 +36,68 @@ export MASTER_PORT=29500 # data/checkpoint args DATA_DIR="/lustre/orion/csc547/proj-shared/parallel_deep_learning/book_corpus" - CHECKPOINT_PATH="${DATA_DIR}/checkpoints" VOCAB_FILE="${DATA_DIR}/gpt2-vocab.json" MERGE_FILE="${DATA_DIR}/gpt2-merges.txt" DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document" + ## ARCHITECTURE DETAILS +# +# +# 5B +NUM_LAYERS=24 +NUM_HEADS=32 +HIDDEN_SIZE=4096 +# +# 10B +#NUM_LAYERS=32 +#NUM_HEADS=40 +#HIDDEN_SIZE=5120 + # 20B -NUM_LAYERS=32 -HIDDEN_SIZE=7168 -NUM_HEADS=56 +#NUM_LAYERS=32 +#NUM_HEADS=56 +#HIDDEN_SIZE=7168 # 40B -NUM_LAYERS=38 -HIDDEN_SIZE=9216 -NUM_HEADS=72 +#NUM_LAYERS=38 +#NUM_HEADS=72 +#HIDDEN_SIZE=9216 +# +# 80B +#NUM_LAYERS=42 +#NUM_HEADS=96 +#HIDDEN_SIZE=12288 +# + +# 160B +#NUM_LAYERS=84 +#NUM_HEADS=96 +#HIDDEN_SIZE=12288 ## PARALLELISM DETAILS -COLUMN_TENSOR_PARR=1 +# ROW_TENSOR_PARR=2 -DEPTH_TENSOR_PARR=256 +COLUMN_TENSOR_PARR=1 +DEPTH_TENSOR_PARR=4 PIPE_PARR=1 -CACHE_LAYERS=25 +CACHE_LAYERS=24 OVERLAP=True + +MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) +DP=$(( GPUS / MP )) ## BATCH SIZES -MICRO_BATCH_SIZE=2048 -GLOBAL_BATCH_SIZE=2048 + +GLOBAL_BATCH_SIZE=16 +MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) SEQUENCE_LENGTH=2048 TRAIN_ITERS=10 + +config="r-${ROW_TENSOR_PARR}-c-${COLUMN_TENSOR_PARR}-d-${DEPTH_TENSOR_PARR}-g-${GPUS}" + GPT_ARGS=" --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ @@ -103,13 +139,17 @@ then --num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}" fi + DATA_ARGS=" --data-path $DATA_PATH \ --vocab-file $VOCAB_FILE \ --merge-file $MERGE_FILE \ - --split 949,50,1 + --split 949,50,1 \ + + --custom-dataloader " + 
OUTPUT_ARGS=" --log-interval 1 \ --save-interval 10000 \ @@ -117,7 +157,7 @@ OUTPUT_ARGS=" --eval-iters 1 " -SCRIPT="python -u pretrain_gpt.py \ +SCRIPT="python -u pretrain_lit_gpt.py \ $GPT_ARGS \ $DATA_ARGS \ $OUTPUT_ARGS \ @@ -128,6 +168,7 @@ SCRIPT="python -u pretrain_gpt.py \ #--load $CHECKPOINT_PATH +export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/ssingh37/csc547/lit-gpt-dev" export OMP_NUM_THREADS=7 run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ${SCRIPT}" diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh new file mode 100755 index 0000000000..43445e857b --- /dev/null +++ b/examples/run_axonn_amd_tinyllama.sh @@ -0,0 +1,157 @@ +#!/bin/bash +#SBATCH -p batch +#SBATCH -A CSC569 +#SBATCH -o /lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/logs/test.out + +echo "This TinyLLAMA script will work for <=512 GPUs." + +## loading python venv +module load cray-python +. /lustre/orion/scratch/ssingh37/csc547/venv_axonn_pt_2.1/bin/activate +module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses + +export MPICH_GPU_SUPPORT_ENABLED=0 +export FI_CXI_ATS=0 +export HSA_FORCE_FINE_GRAIN_PCIE=1 +export NCCL_CROSS_NIC=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +## point this to the AWS plugin (you should have compiled this previously) +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/lustre/orion/scratch/ssingh37/csc547/aws-ofi-rccl/build/lib" + +NNODES=$SLURM_JOB_NUM_NODES +GPUS_PER_NODE=8 ## change as per your machine +GPUS=$(( NNODES * GPUS_PER_NODE )) +export MASTER_ADDR=$(hostname) +export MASTER_PORT=29500 + + +# these are redundant for tiny-llams, so ignore +DATA_DIR="/lustre/orion/csc547/proj-shared/parallel_deep_learning/book_corpus" +VOCAB_FILE="${DATA_DIR}/gpt2-vocab.json" +MERGE_FILE="${DATA_DIR}/gpt2-merges.txt" +DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document" + +# we will save and load model checkpoints here +CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" + +#TODO: tensorboard logging +#TENSORBOARD_DIR="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/logs" +#mkdir -p ${TENSORBOARD_DIR} + +# tiny-llama1.1B +# https://github.com/azshue/lit-gpt-dev/blob/tiny-llama/lit_gpt/config.py +# +GLOBAL_BATCH_SIZE=512 +SEQUENCE_LENGTH=2048 +NUM_LAYERS=22 +NUM_HEADS=32 +HIDDEN_SIZE=2048 +FFN_HIDDEN_SIZE=5632 +NUM_QUERY_GROUPS=4 +TOKENS_IN_BILLIONS=3000 + +TRAIN_ITERS=$(( TOKENS_IN_BILLIONS * 1000000000 / GLOBAL_BATCH_SIZE / SEQUENCE_LENGTH + 100 )) +echo "Number of training iterations : ${TRAIN_ITERS}" + +## AxoNN args +## These do not affect the science +ROW_TENSOR_PARR=1 +COLUMN_TENSOR_PARR=1 +DEPTH_TENSOR_PARR=2 +PIPE_PARR=1 +CACHE_LAYERS=22 +OVERLAP=True + + +## DERIVED ARGUMENTS (ignore) +MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) +DP=$(( GPUS / MP )) +MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) + + +config="r-${ROW_TENSOR_PARR}-c-${COLUMN_TENSOR_PARR}-d-${DEPTH_TENSOR_PARR}-g-${GPUS}" + +GPT_ARGS=" + --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ + --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ + --depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \ + --pipeline-model-parallel-size ${PIPE_PARR} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_HEADS} \ + --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ + --seq-length ${SEQUENCE_LENGTH} \ + --max-position-embeddings ${SEQUENCE_LENGTH} \ + --micro-batch-size ${MICRO_BATCH_SIZE} \ + 
--global-batch-size ${GLOBAL_BATCH_SIZE} \ + --lr 4.0e-4 \ + --train-iters ${TRAIN_ITERS} \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 4.0e-5 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 2000 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --bf16 \ + --no-gradient-accumulation-fusion \ + --use-amd \ + --recompute-granularity full \ + --recompute-method uniform \ + --recompute-num-layers 1 \ + --use-flash-attn \ + --swiglu \ + --use-rotary-position-embeddings \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups ${NUM_QUERY_GROUPS} +" +# --no-gradient-accumulation-fusion is neede on AMD +# --use-amd disables features incompatible with AMD +# --swiglu makes ParallelMLP equivalent to LLAMAMLP + +if [[ $OVERLAP == "True" ]] +then + GPT_ARGS="${GPT_ARGS} \ + --overlap-axonn-comm \ + --overlap-axonn-reduce-scatter \ + --overlap-axonn-all-gather\ + --num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}" +fi + +# the data-path vocab-file and marge-file args are redundant here +# the custom-dataloader is switching to the lit gpt dataloader +DATA_ARGS=" + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 \ + --custom-dataloader +" + + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 1000 \ + --eval-interval 1000 \ + --eval-iters 100 \ +" + +SCRIPT="python -u pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --distributed-backend nccl \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH +" + + +export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/ssingh37/csc547/lit-gpt-dev" +export OMP_NUM_THREADS=7 +run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ${SCRIPT}" + +echo ${run_cmd} +eval ${run_cmd} +set +x diff --git a/megatron/arguments.py b/megatron/arguments.py index 7f975bad14..62d1ae3598 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1081,6 +1081,8 @@ def _add_validation_args(parser): def _add_data_args(parser): group = parser.add_argument_group(title='data and dataloader') + group.add_argument('--custom-dataloader', help="using custom dataloader, bypass megatron's" + "dataset/dataloader creation", action='store_true') group.add_argument('--data-path', nargs='*', default=None, help='Path to the training dataset. Accepted format:' '1) a single data path, 2) multiple datasets in the' diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 37e225a54f..632507bd1b 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -602,7 +602,7 @@ def forward(self, hidden_states, attention_mask, dim=3) # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - - query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) + query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) diff --git a/megatron/training.py b/megatron/training.py index 2dddca9679..2c632ab86f 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -121,22 +121,27 @@ def pretrain(train_valid_test_dataset_provider, # Data stuff. 
timers('train/valid/test-data-iterators-setup', log_level=0).start( barrier=True) - if args.virtual_pipeline_model_parallel_size is not None: - all_data_iterators = [ - build_train_valid_test_data_iterators( - train_valid_test_dataset_provider) - for _ in range(len(model)) - ] - train_data_iterator = [data_iterators[0] - for data_iterators in all_data_iterators] - valid_data_iterator = [data_iterators[1] - for data_iterators in all_data_iterators] - test_data_iterator = [data_iterators[2] - for data_iterators in all_data_iterators] + if not args.custom_dataloader: + if args.virtual_pipeline_model_parallel_size is not None: + all_data_iterators = [ + build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) + for _ in range(len(model)) + ] + train_data_iterator = [data_iterators[0] + for data_iterators in all_data_iterators] + valid_data_iterator = [data_iterators[1] + for data_iterators in all_data_iterators] + test_data_iterator = [data_iterators[2] + for data_iterators in all_data_iterators] + else: + train_data_iterator, valid_data_iterator, test_data_iterator \ + = build_train_valid_test_data_iterators( + train_valid_test_dataset_provider) else: - train_data_iterator, valid_data_iterator, test_data_iterator \ - = build_train_valid_test_data_iterators( - train_valid_test_dataset_provider) + assert args.virtual_pipeline_model_parallel_size is None + train_data_iterator, valid_data_iterator = train_valid_test_dataset_provider(0) + test_data_iterator = None timers('train/valid/test-data-iterators-setup').stop() print_datetime('after dataloaders are built') @@ -650,8 +655,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, total_loss_dict[skipped_iters_key]) log_string += ' number of nan iterations: {:3d} |'.format( total_loss_dict[nan_iters_key]) - log_string += ' theoretical FLOP/s: {:.3f} TFLOP/s | '.format(get_flops(elapsed_time_per_iteration)) - log_string += ' model size: {:.3f} B params | '.format(get_params()) + #log_string += ' theoretical FLOP/s: {:.3f} TFLOP/s | '.format(get_flops(elapsed_time_per_iteration)) + #log_string += ' model size: {:.3f} B params | '.format(get_params()) curr, peak = get_mem() log_string += ' memory used by tensors {:.3f} GB ( peak {:.3f} GB)'.format(curr, peak) @@ -677,6 +682,8 @@ def get_flops(batch_time): vocab_size = args.padded_vocab_size num_gpus = torch.distributed.get_world_size() teraflop_in_batch = 96*batch_size*seq_length*num_layers*(hidden_size**2)*(1+seq_length/(6*hidden_size)+(vocab_size)/(16*num_layers*hidden_size))/(1e12) + if args.swiglu: + teraflop_in_batch += (2*batch_size*seq_length*4*(hidden_size**2))*4*num_layers / 1e12 return teraflop_in_batch/batch_time/num_gpus @@ -688,6 +695,8 @@ def get_params(): hidden_size = args.hidden_size vocab_size = args.padded_vocab_size params = 12 * num_layers * (hidden_size ** 2)* ( 1 + 13/(12*hidden_size) + (vocab_size + seq_length)/(12 * num_layers * hidden_size)) / 1e9 + if args.swiglu: + params += num_layers * 4 * hidden_size ** 2 / 1e9 return params def get_mem(): diff --git a/pretrain_gpt.py b/pretrain_gpt.py index bdc50cc0ac..1044c21f61 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -24,6 +24,8 @@ from axonn import axonn as ax from contextlib import nullcontext +from custom_litgpt_dataloader.data_util import create_dataloaders + def model_provider(pre_process=True, post_process=True): """Build the model.""" args = get_args() @@ -54,6 +56,9 @@ def get_batch(data_iterator): data = next(data_iterator) else: data = None + + if 
args.custom_dataloader: + data = {"text": data} data_b = tensor_parallel.broadcast_data(keys, data, datatype) @@ -67,8 +72,8 @@ def get_batch(data_iterator): attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, tokenizer.eod, - args.reset_position_ids, - args.reset_attention_mask, + args.reset_position_ids, # for this to work we need access to the tokenizer + args.reset_attention_mask, # for this to work we need access to the tokenizer args.eod_mask_loss) return tokens, labels, loss_mask, attention_mask, position_ids @@ -141,28 +146,42 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() - print_rank_0('> building train, validation, and test datasets ' - 'for GPT ...') - train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - data_prefix=args.data_path, - splits_string=args.split, - train_valid_test_num_samples=train_val_test_num_samples, - seq_length=args.seq_length, - seed=args.seed, - skip_warmup=(not args.mmap_warmup), - train_data_prefix=args.train_data_path, - valid_data_prefix=args.valid_data_path, - test_data_prefix=args.test_data_path, - data_cache_path=args.data_cache_path) - print_rank_0("> finished creating GPT datasets ...") + if args.custom_dataloader: + train_iterator, valid_iterator = create_dataloaders( + batch_size= args.micro_batch_size, + block_size= args.seq_length, + ) + + # these flags are set within megatron in + # the OG dataloader + args.do_train = True + args.do_valid = True + args.do_test = False - return train_ds, valid_ds, test_ds + return train_iterator, valid_iterator + else: + print_rank_0('> building train, validation, and test datasets ' + 'for GPT ...') + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds def set_device_and_init_torch_dist(): from mpi4py import MPI import os - + MPI.Init() world_rank = MPI.COMM_WORLD.Get_rank() world_size = MPI.COMM_WORLD.Get_size() From ac0debba7dca988cfd015f8b19c3791aad962a25 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 25 Jan 2024 06:43:30 -0500 Subject: [PATCH 02/25] add the litgpt dataloader --- custom_litgpt_dataloader/__init__.py | 0 custom_litgpt_dataloader/data_util.py | 93 +++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 custom_litgpt_dataloader/__init__.py create mode 100644 custom_litgpt_dataloader/data_util.py diff --git a/custom_litgpt_dataloader/__init__.py b/custom_litgpt_dataloader/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/custom_litgpt_dataloader/data_util.py b/custom_litgpt_dataloader/data_util.py new file mode 100644 index 0000000000..304ee69558 --- /dev/null +++ b/custom_litgpt_dataloader/data_util.py @@ -0,0 +1,93 @@ +from lit_gpt.packed_dataset import CombinedDataset, PackedDataset +from lit_gpt.utils import CycleIterator, chunked_cross_entropy, num_parameters +from torch.utils.data import DataLoader +from pathlib import Path +from axonn import axonn as ax +import torch.distributed as dist +from typing import Tuple, Union, Optional +from 
lit_gpt.utils import CycleIterator + +data_config = [ + ("train_slimpajama", 69.3584), + ("train_starcoder", 30.6), + # 0.693584, 0.306416) + # ("c4", 15.0), + # ("cc", 67.0), + # ("github", 4.5), + # ("stackexchange", 2.0), + # ("wikipedia", 4.5), +] + + +def create_dataloader( + batch_size: int, block_size: int, data_dir: Path, shuffle: bool = True, seed: int = 12345 +) -> DataLoader: + datasets = [] + for prefix, _ in data_config: + filenames = list(data_dir.glob(f"{prefix}*")) + if not filenames: + raise FileNotFoundError( + f"No files found at {str(data_dir)} with prefix {prefix}. Did you forget to run `prepare_redpajama.py`?" + ) + dataset = PackedDataset( + filenames, + n_chunks=4, + block_size=block_size, + shuffle=shuffle, + seed=seed, + num_processes=ax.config.G_data, + process_rank=ax.config.data_parallel_rank, + ) + datasets.append(dataset) + + if not datasets: + raise RuntimeError( + f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." + ) + + weights = [weight for _, weight in data_config] + sum_weights = sum(weights) + weights = [el / sum_weights for el in weights] + + combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) + + return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) + + +def create_dataloaders( + batch_size: int, + block_size: int, + train_data_dir: Path = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample"), + val_data_dir: Optional[Path] = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample"), + seed: int = 12345, +) -> Tuple[DataLoader, DataLoader]: + # Increase by one because we need the next word as well + effective_block_size = block_size + 1 + train_dataloader = create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + data_dir=train_data_dir, + shuffle=True, + seed=seed, + ) + val_dataloader = ( + create_dataloader( + batch_size=batch_size, + block_size=effective_block_size, + data_dir=val_data_dir, + shuffle=False, + seed=seed, + ) + if val_data_dir + else None + ) + return CycleIterator(train_dataloader), CycleIterator(val_dataloader) + +if __name__ == "__main__": + ax.init(G_inter=1, G_data=1, G_intra_r=8) + train_loader, val_loader = create_dataloaders( + batch_size=32, + block_size=1024, #seuqnce length? 
+ ) + data = next(train_loader) + print(dist.get_rank(), ":", data.view(-1)[:5]) From f99a7e5b88e28177209823da4b646a26914997b4 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sat, 27 Jan 2024 11:52:59 -0500 Subject: [PATCH 03/25] test scripts on 128 GPUs --- custom_litgpt_dataloader/data_util.py | 3 +-- examples/run_axonn_amd_tinyllama.sh | 23 ++++++++++++++++------- megatron/training.py | 2 +- pretrain_gpt.py | 5 +---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/custom_litgpt_dataloader/data_util.py b/custom_litgpt_dataloader/data_util.py index 304ee69558..93aa0e5cd6 100644 --- a/custom_litgpt_dataloader/data_util.py +++ b/custom_litgpt_dataloader/data_util.py @@ -1,11 +1,10 @@ from lit_gpt.packed_dataset import CombinedDataset, PackedDataset -from lit_gpt.utils import CycleIterator, chunked_cross_entropy, num_parameters +from lit_gpt.utils import CycleIterator from torch.utils.data import DataLoader from pathlib import Path from axonn import axonn as ax import torch.distributed as dist from typing import Tuple, Union, Optional -from lit_gpt.utils import CycleIterator data_config = [ ("train_slimpajama", 69.3584), diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 43445e857b..40aa85d2f3 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -5,20 +5,29 @@ echo "This TinyLLAMA script will work for <=512 GPUs." -## loading python venv -module load cray-python -. /lustre/orion/scratch/ssingh37/csc547/venv_axonn_pt_2.1/bin/activate +module load PrgEnv-cray +module load cray-python/3.9.13.1 +. /ccs/home/ssingh37/axonn_venv/bin/activate module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses +module load libfabric +## these lines enable CUDA aware MPI +#module load craype-accel-amd-gfx90a export MPICH_GPU_SUPPORT_ENABLED=0 +#export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib" +## this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/ssingh37/aws-ofi-rccl/build/lib" +#export NCCL_DEBUG=INFO export FI_CXI_ATS=0 export HSA_FORCE_FINE_GRAIN_PCIE=1 +#export NCCL_SOCKET_IFNAME=hsn +# super important +#export NCCL_NET_GDR_LEVEL=4 +#export NCCL_P2P_LEVEL=4 +## this improves cross node bandwidth for some cases export NCCL_CROSS_NIC=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 -## point this to the AWS plugin (you should have compiled this previously) -export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/lustre/orion/scratch/ssingh37/csc547/aws-ofi-rccl/build/lib" - NNODES=$SLURM_JOB_NUM_NODES GPUS_PER_NODE=8 ## change as per your machine GPUS=$(( NNODES * GPUS_PER_NODE )) @@ -87,7 +96,7 @@ GPT_ARGS=" --global-batch-size ${GLOBAL_BATCH_SIZE} \ --lr 4.0e-4 \ --train-iters ${TRAIN_ITERS} \ - --lr-decay-iters 320000 \ + --lr-decay-iters ${TRAIN_ITERS} \ --lr-decay-style cosine \ --min-lr 4.0e-5 \ --weight-decay 1e-1 \ diff --git a/megatron/training.py b/megatron/training.py index 2c632ab86f..0064970a97 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -655,7 +655,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, total_loss_dict[skipped_iters_key]) log_string += ' number of nan iterations: {:3d} |'.format( total_loss_dict[nan_iters_key]) - #log_string += ' theoretical FLOP/s: {:.3f} TFLOP/s | '.format(get_flops(elapsed_time_per_iteration)) + log_string += ' theoretical FLOP/s: {:.3f} TFLOP/s | '.format(get_flops(elapsed_time_per_iteration)) #log_string += 
' model size: {:.3f} B params | '.format(get_params()) curr, peak = get_mem() log_string += ' memory used by tensors {:.3f} GB ( peak {:.3f} GB)'.format(curr, peak) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 1044c21f61..cce7add610 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -1,7 +1,7 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Pretrain GPT""" - +from mpi4py import MPI import os import torch from functools import partial @@ -179,9 +179,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): def set_device_and_init_torch_dist(): - from mpi4py import MPI - import os - MPI.Init() world_rank = MPI.COMM_WORLD.Get_rank() world_size = MPI.COMM_WORLD.Get_size() From 4105129e468fa54b4d8a0f2da11addf3eb008640 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sat, 27 Jan 2024 12:24:32 -0500 Subject: [PATCH 04/25] remove harcoded output file from script --- examples/run_axonn_amd_tinyllama.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 40aa85d2f3..5d929cc5b1 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -1,7 +1,6 @@ #!/bin/bash #SBATCH -p batch #SBATCH -A CSC569 -#SBATCH -o /lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/logs/test.out echo "This TinyLLAMA script will work for <=512 GPUs." From a43c4d090e0b2f2ff2cd2862ae51c89a751ec8fd Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 10:49:26 -0500 Subject: [PATCH 05/25] change default dataset location name --- custom_litgpt_dataloader/data_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_litgpt_dataloader/data_util.py b/custom_litgpt_dataloader/data_util.py index 93aa0e5cd6..41cff40d72 100644 --- a/custom_litgpt_dataloader/data_util.py +++ b/custom_litgpt_dataloader/data_util.py @@ -56,8 +56,8 @@ def create_dataloader( def create_dataloaders( batch_size: int, block_size: int, - train_data_dir: Path = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample"), - val_data_dir: Optional[Path] = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample"), + train_data_dir: Path = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample_tinyllama_tokd"), + val_data_dir: Optional[Path] = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample_tinyllama_tokd"), seed: int = 12345, ) -> Tuple[DataLoader, DataLoader]: # Increase by one because we need the next word as well From 522320a7eee4562845edfc98b1b7656abf40b7b5 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 10:49:54 -0500 Subject: [PATCH 06/25] make workers 0 and get rank from slurm --- examples/run_axonn_amd.sh | 177 ---------------------------- examples/run_axonn_amd_tinyllama.sh | 8 +- 2 files changed, 5 insertions(+), 180 deletions(-) delete mode 100755 examples/run_axonn_amd.sh diff --git a/examples/run_axonn_amd.sh b/examples/run_axonn_amd.sh deleted file mode 100755 index 2f1ccc7899..0000000000 --- a/examples/run_axonn_amd.sh +++ /dev/null @@ -1,177 +0,0 @@ -#!/bin/bash -#SBATCH -p batch -#SBATCH -A CSC547 -#SBATCH -t 00:20:00 - -# Runs the "345M" parameter model - -module load cray-python -. 
/lustre/orion/scratch/ssingh37/csc547/venv_axonn_pt_2.1/bin/activate -module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses - -## these lines enable CUDA aware MPI -#module load craype-accel-amd-gfx90a -export MPICH_GPU_SUPPORT_ENABLED=0 -#export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib" -## this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) -export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/lustre/orion/scratch/ssingh37/csc547/aws-ofi-rccl/build/lib" -#export NCCL_DEBUG=INFO -export FI_CXI_ATS=0 -export HSA_FORCE_FINE_GRAIN_PCIE=1 -#export NCCL_SOCKET_IFNAME=hsn -# super important -#export NCCL_NET_GDR_LEVEL=4 -#export NCCL_P2P_LEVEL=4 -## this improves cross node bandwidth for some cases -export NCCL_CROSS_NIC=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -NNODES=$SLURM_JOB_NUM_NODES -GPUS_PER_NODE=8 ## change as per your machine -GPUS=$(( NNODES * GPUS_PER_NODE )) - -export MASTER_ADDR=$(hostname) -export MASTER_PORT=29500 - -# data/checkpoint args -DATA_DIR="/lustre/orion/csc547/proj-shared/parallel_deep_learning/book_corpus" - -CHECKPOINT_PATH="${DATA_DIR}/checkpoints" -VOCAB_FILE="${DATA_DIR}/gpt2-vocab.json" -MERGE_FILE="${DATA_DIR}/gpt2-merges.txt" -DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document" - - -## ARCHITECTURE DETAILS -# -# -# 5B -NUM_LAYERS=24 -NUM_HEADS=32 -HIDDEN_SIZE=4096 -# -# 10B -#NUM_LAYERS=32 -#NUM_HEADS=40 -#HIDDEN_SIZE=5120 - -# 20B -#NUM_LAYERS=32 -#NUM_HEADS=56 -#HIDDEN_SIZE=7168 - -# 40B -#NUM_LAYERS=38 -#NUM_HEADS=72 -#HIDDEN_SIZE=9216 -# -# 80B -#NUM_LAYERS=42 -#NUM_HEADS=96 -#HIDDEN_SIZE=12288 -# - -# 160B -#NUM_LAYERS=84 -#NUM_HEADS=96 -#HIDDEN_SIZE=12288 - -## PARALLELISM DETAILS -# -ROW_TENSOR_PARR=2 -COLUMN_TENSOR_PARR=1 -DEPTH_TENSOR_PARR=4 -PIPE_PARR=1 -CACHE_LAYERS=24 -OVERLAP=True - - -MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) -DP=$(( GPUS / MP )) -## BATCH SIZES - -GLOBAL_BATCH_SIZE=16 -MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) -SEQUENCE_LENGTH=2048 -TRAIN_ITERS=10 - - -config="r-${ROW_TENSOR_PARR}-c-${COLUMN_TENSOR_PARR}-d-${DEPTH_TENSOR_PARR}-g-${GPUS}" - -GPT_ARGS=" - --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ - --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ - --depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \ - --pipeline-model-parallel-size ${PIPE_PARR} \ - --num-layers ${NUM_LAYERS} \ - --hidden-size ${HIDDEN_SIZE} \ - --num-attention-heads ${NUM_HEADS} \ - --seq-length ${SEQUENCE_LENGTH} \ - --max-position-embeddings ${SEQUENCE_LENGTH} \ - --micro-batch-size ${MICRO_BATCH_SIZE} \ - --global-batch-size ${GLOBAL_BATCH_SIZE} \ - --lr 0.00015 \ - --train-iters ${TRAIN_ITERS} \ - --lr-decay-iters 320000 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --lr-warmup-fraction .01 \ - --clip-grad 1.0 \ - --bf16 \ - --no-gradient-accumulation-fusion \ - --use-amd \ - --recompute-granularity full \ - --recompute-method uniform \ - --recompute-num-layers 1 \ - --use-flash-attn \ -" -# --no-gradient-accumulation-fusion is neede on AMD -# --use-amd disables features incompatible with AMD - - -if [[ $OVERLAP == "True" ]] -then - GPT_ARGS="${GPT_ARGS} \ - --overlap-axonn-comm \ - --overlap-axonn-reduce-scatter \ - --overlap-axonn-all-gather\ - --num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}" -fi - - -DATA_ARGS=" - --data-path $DATA_PATH \ - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --split 949,50,1 \ - - --custom-dataloader -" - - 
-OUTPUT_ARGS=" - --log-interval 1 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 1 -" - -SCRIPT="python -u pretrain_lit_gpt.py \ - $GPT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ -" - - #--save $CHECKPOINT_PATH \ - #--load $CHECKPOINT_PATH - - -export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/ssingh37/csc547/lit-gpt-dev" -export OMP_NUM_THREADS=7 -run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ${SCRIPT}" - -echo ${run_cmd} -eval ${run_cmd} -set +x diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 5d929cc5b1..9008f287ec 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -32,6 +32,7 @@ GPUS_PER_NODE=8 ## change as per your machine GPUS=$(( NNODES * GPUS_PER_NODE )) export MASTER_ADDR=$(hostname) export MASTER_PORT=29500 +export WORLD_SIZE=${GPUS} # these are redundant for tiny-llams, so ignore @@ -41,7 +42,7 @@ MERGE_FILE="${DATA_DIR}/gpt2-merges.txt" DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document" # we will save and load model checkpoints here -CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" +CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints_2" #TODO: tensorboard logging #TENSORBOARD_DIR="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/logs" @@ -135,7 +136,8 @@ DATA_ARGS=" --vocab-file $VOCAB_FILE \ --merge-file $MERGE_FILE \ --split 949,50,1 \ - --custom-dataloader + --custom-dataloader \ + --num-workers 0 " @@ -158,7 +160,7 @@ SCRIPT="python -u pretrain_gpt.py \ export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/ssingh37/csc547/lit-gpt-dev" export OMP_NUM_THREADS=7 -run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ${SCRIPT}" +run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ./examples/get_rank_from_slurm.sh ${SCRIPT}" echo ${run_cmd} eval ${run_cmd} From 90f4e9f85add4d80edc97584e269e875d9ed5269 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 10:50:04 -0500 Subject: [PATCH 07/25] remove MPI dependency for real --- examples/get_rank_from_slurm.sh | 4 ++++ 1 file changed, 4 insertions(+) create mode 100755 examples/get_rank_from_slurm.sh diff --git a/examples/get_rank_from_slurm.sh b/examples/get_rank_from_slurm.sh new file mode 100755 index 0000000000..881bdc4a0e --- /dev/null +++ b/examples/get_rank_from_slurm.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# select_gpu_device wrapper script +export RANK=${SLURM_PROCID} +exec $* From 1ea6b8eab6f81da111da67343ed981141ba35199 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 10:51:05 -0500 Subject: [PATCH 08/25] init torch directly and remove mpi4py import --- pretrain_gpt.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index cce7add610..bea5b9626d 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -1,7 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
"""Pretrain GPT""" -from mpi4py import MPI import os import torch from functools import partial @@ -178,7 +177,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): return train_ds, valid_ds, test_ds -def set_device_and_init_torch_dist(): +def set_device_and_init_torch_dist_mpi(): world_rank = MPI.COMM_WORLD.Get_rank() world_size = MPI.COMM_WORLD.Get_size() @@ -203,9 +202,12 @@ def set_device_and_init_torch_dist(): os.environ["WORLD_SIZE"] = str(world_size) + if __name__ == "__main__": - set_device_and_init_torch_dist() + #set_device_and_init_torch_dist_mpi() #torch.cuda.set_per_process_memory_fraction(0.5) # 40GB + # env variables being set in slurm + torch.distributed.init_process_group() pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder, From a2feee147238c30558a5d68de371fee6d8eeeea2 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 10:53:22 -0500 Subject: [PATCH 09/25] add venv baed setup for megatron axonn --- examples/install_everything_on_frontier.sh | 68 ++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 examples/install_everything_on_frontier.sh diff --git a/examples/install_everything_on_frontier.sh b/examples/install_everything_on_frontier.sh new file mode 100644 index 0000000000..ac7359d65e --- /dev/null +++ b/examples/install_everything_on_frontier.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +# Setup Virtual Environment +echo "Setting up Virtual Environment" +module load cray-python +python -m venv ./my-venv --system-site-packages +cd my-venv +. bin/activate + +# PyTorch +echo "Installing PyTorch" +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6 + +module load amd-mixed/5.6.0 +module load PrgEnv-cray + +# mpi4py +echo "Installing mpi4py" +module load craype-accel-amd-gfx90a +export MPICH_GPU_SUPPORT_ENABLED=1 +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib" +echo ${LD_LIBRARY_PATH} + +MPICC=CC python -m pip install --ignore-installed --no-cache-dir mpi4py + +# Flash Attention +echo "Installing Flash Attention" +git clone https://github.com/ROCmSoftwarePlatform/flash-attention +cd flash-attention +vi setup.py -c ':%s/c++20/c++17/g' -c ':wq' +CC=cc CXX=CC PYTORCH_ROCM_ARCH='gfx90a' GPU_ARCHS='gfx90a' pip install -v . + +# Apex +echo "Installing Apex" +cd .. +git clone https://github.com/ROCmSoftwarePlatform/apex +cd apex +git checkout release/1.1.0 +CC=cc CXX=CC PYTORCH_ROCM_ARCH='gfx90a' GPU_ARCHS='gfx90a' python setup.py install --cpp_ext --cuda_ext + +# RCCL Plugin +echo "Installing RCCL Plugin" +cd .. +git clone https://github.com/ROCmSoftwarePlatform/aws-ofi-rccl +cd aws-ofi-rccl +module load libtool +./autogen.sh +CC=cc CXX=CC ./configure --with-libfabric=/opt/cray/libfabric/1.15.0.0 --with-hip=/opt/rocm-5.6.0/ --with-rccl="$(dirname "$(pwd)")"/lib/python3.9/site-packages/torch/lib/ --prefix="$(dirname "$(pwd)")"/aws-ofi-rccl/build/ +CC=cc CXX=CC make -j install + +cd .. + +# AxoNN +echo "Installing AxoNN" +git clone https://github.com/axonn-ai/axonn.git +cd axonn +pip install -e . + +cd .. + +# Megatron-AxoNN +echo "Installing Megatron-AxoNN" +git clone https://github.com/axonn-ai/Megatron-AxoNN.git + +pip install regex + +echo "Done!" 
+ From f0f89d00eac03f5901ae0c3cedcdcd2764dbec2a Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 11:01:29 -0500 Subject: [PATCH 10/25] add instructions for frontier --- examples/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 examples/README.md diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000..cb519fd82a --- /dev/null +++ b/examples/README.md @@ -0,0 +1,17 @@ +# How to setup on frontier + +## Installing all dependencies +``` +cd /lustre/orion/scratch/$(whoami)/csc569/ +bash install_everything_on_frontier.sh +``` + +This should work, let Siddharth know if it doesn't + + +## Training TinyLLaMA +To launch on 16 nodes (128 GPUs) for 2 hours +``` +sbatch -N 128 -t 02:00:00 examples/run_axonn_amd_tinyllama.sh +``` + From 8ba0bd5d9987757c2b62cf131ee916ddd5e05dc7 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 11:04:07 -0500 Subject: [PATCH 11/25] add branch name in README --- examples/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/README.md b/examples/README.md index cb519fd82a..439a2a32f6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -12,6 +12,7 @@ This should work, let Siddharth know if it doesn't ## Training TinyLLaMA To launch on 16 nodes (128 GPUs) for 2 hours ``` +## checkout the tiny-llama branch sbatch -N 128 -t 02:00:00 examples/run_axonn_amd_tinyllama.sh ``` From 7d256ed21caca4d359dd4b57e453778d662a7f05 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 14:23:43 -0500 Subject: [PATCH 12/25] add --lit-gpt-data-path as an argument --- custom_litgpt_dataloader/data_util.py | 10 ++- examples/run_axonn_amd_tinyllama.sh | 104 +++++++++++++++++--------- megatron/arguments.py | 3 + pretrain_gpt.py | 2 + 4 files changed, 79 insertions(+), 40 deletions(-) diff --git a/custom_litgpt_dataloader/data_util.py b/custom_litgpt_dataloader/data_util.py index 41cff40d72..0412110b9e 100644 --- a/custom_litgpt_dataloader/data_util.py +++ b/custom_litgpt_dataloader/data_util.py @@ -56,11 +56,13 @@ def create_dataloader( def create_dataloaders( batch_size: int, block_size: int, - train_data_dir: Path = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample_tinyllama_tokd"), - val_data_dir: Optional[Path] = Path("/lustre/orion/csc569/proj-shared/language_datasets/spj_star_combined_sample_tinyllama_tokd"), - seed: int = 12345, + train_data_dir: str, + val_data_dir: str, + seed: int = 12345, #this seed is independent of megatron's seeds ) -> Tuple[DataLoader, DataLoader]: # Increase by one because we need the next word as well + train_data_dir = Path(train_data_dir) + val_data_dir = Path(val_data_dir) effective_block_size = block_size + 1 train_dataloader = create_dataloader( batch_size=batch_size, @@ -86,7 +88,7 @@ def create_dataloaders( ax.init(G_inter=1, G_data=1, G_intra_r=8) train_loader, val_loader = create_dataloaders( batch_size=32, - block_size=1024, #seuqnce length? + block_size=1024, #sequence length ) data = next(train_loader) print(dist.get_rank(), ":", data.view(-1)[:5]) diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 9008f287ec..9f5db3d3f8 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -2,68 +2,77 @@ #SBATCH -p batch #SBATCH -A CSC569 + +userid=$(whoami) +# These are the two things you need to change as per your setup +# 1. 
Make LD_LIBRARY_PATH point to wherever your plugin is installed +# this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/$userid/aws-ofi-rccl/build/lib" +# 2. Make PYTHONPATH point to your local clone of litgpt +export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/$userid/csc547/lit-gpt-dev" + +# The rest of the script should work as it is + echo "This TinyLLAMA script will work for <=512 GPUs." module load PrgEnv-cray module load cray-python/3.9.13.1 -. /ccs/home/ssingh37/axonn_venv/bin/activate +. /ccs/home/$userid/axonn_venv/bin/activate module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses module load libfabric -## these lines enable CUDA aware MPI -#module load craype-accel-amd-gfx90a export MPICH_GPU_SUPPORT_ENABLED=0 -#export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib" -## this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) -export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/ssingh37/aws-ofi-rccl/build/lib" -#export NCCL_DEBUG=INFO + +## some RCCL env variables export FI_CXI_ATS=0 export HSA_FORCE_FINE_GRAIN_PCIE=1 -#export NCCL_SOCKET_IFNAME=hsn -# super important -#export NCCL_NET_GDR_LEVEL=4 -#export NCCL_P2P_LEVEL=4 -## this improves cross node bandwidth for some cases export NCCL_CROSS_NIC=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 +## calculating the number of nodes and GPUs NNODES=$SLURM_JOB_NUM_NODES GPUS_PER_NODE=8 ## change as per your machine GPUS=$(( NNODES * GPUS_PER_NODE )) + +# setting variables for torch.distributed export MASTER_ADDR=$(hostname) export MASTER_PORT=29500 export WORLD_SIZE=${GPUS} +# train_data_dir and val_data_dir are set to this as of now +DATADIR="/lustre/orion/csc569/proj-shared/language_datasets/" +DATASET="spj_star_combined_full_tinyllama_tokd" +DATAPATH="$DATADIR/$DATASET" + # these are redundant for tiny-llams, so ignore -DATA_DIR="/lustre/orion/csc547/proj-shared/parallel_deep_learning/book_corpus" -VOCAB_FILE="${DATA_DIR}/gpt2-vocab.json" -MERGE_FILE="${DATA_DIR}/gpt2-merges.txt" -DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document" +MEGATRON_TOKENIZER_DIR="/lustre/orion/proj-shared/csc569/book_corpus_megatron" +VOCAB_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-vocab.json" +MERGE_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-merges.txt" + # we will save and load model checkpoints here -CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints_2" +# if these are non-empty training will restart from the latest checkpoint here +# else training will start from scratch +CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" -#TODO: tensorboard logging -#TENSORBOARD_DIR="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/logs" -#mkdir -p ${TENSORBOARD_DIR} -# tiny-llama1.1B +# tiny-llama1.1B architecture shapes # https://github.com/azshue/lit-gpt-dev/blob/tiny-llama/lit_gpt/config.py -# -GLOBAL_BATCH_SIZE=512 -SEQUENCE_LENGTH=2048 NUM_LAYERS=22 NUM_HEADS=32 HIDDEN_SIZE=2048 FFN_HIDDEN_SIZE=5632 NUM_QUERY_GROUPS=4 -TOKENS_IN_BILLIONS=3000 +# batch size, seq length, and iterations +GLOBAL_BATCH_SIZE=512 +SEQUENCE_LENGTH=2048 +TOKENS_IN_BILLIONS=3000 TRAIN_ITERS=$(( TOKENS_IN_BILLIONS * 1000000000 / GLOBAL_BATCH_SIZE / SEQUENCE_LENGTH + 100 )) echo "Number of training iterations : ${TRAIN_ITERS}" -## AxoNN args +## AxoNN parallelism args ## These do not affect the science ROW_TENSOR_PARR=1 COLUMN_TENSOR_PARR=1 @@ -78,8 +87,16 @@ 
MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) DP=$(( GPUS / MP )) MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) - -config="r-${ROW_TENSOR_PARR}-c-${COLUMN_TENSOR_PARR}-d-${DEPTH_TENSOR_PARR}-g-${GPUS}" +# The following args enable LLaMA +# --swiglu makes ParallelMLP equivalent to LLAMAMLP +# --group-query-attention - enables group query attention +# --num-query-groups - number of query groups for group query attention +# --normalization RMSNorm - switch from layernorm to RMSNorm (someone confirm?) +# --use-rotary-position-embeddings - use RoPE embeddings instead of learned position embeddings +# +# The following args disable features not compatible with AMD +# --no-gradient-accumulation-fusion +# --use-amd GPT_ARGS=" --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ @@ -117,10 +134,9 @@ GPT_ARGS=" --group-query-attention \ --num-query-groups ${NUM_QUERY_GROUPS} " -# --no-gradient-accumulation-fusion is neede on AMD -# --use-amd disables features incompatible with AMD -# --swiglu makes ParallelMLP equivalent to LLAMAMLP +## AxoNN specific args for communication optimizations +# these do not affect the ML science if [[ $OVERLAP == "True" ]] then GPT_ARGS="${GPT_ARGS} \ @@ -130,17 +146,34 @@ then --num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}" fi -# the data-path vocab-file and marge-file args are redundant here -# the custom-dataloader is switching to the lit gpt dataloader +# --lit-gpt-data-path - is pointing to your dataset +# currently both train and val splits are taken fron --data-path +# the --custom-dataloader argument bypasses megatron's dataloaders +# --num-workers 0 - disables multiprocesses dataloading +# which can hang jobs at scale + DATA_ARGS=" + --lit-gpt-data-path $DATAPATH \ + --custom-dataloader \ + --num-workers 0 +" + +# these args are for megatron dataloaders +# these are not needed for litgpt, but not passing them +# might give you errors +# THESE DO NOTHING +REDUNDANT_DATA_ARGS=" --vocab-file $VOCAB_FILE \ --merge-file $MERGE_FILE \ --split 949,50,1 \ - --custom-dataloader \ - --num-workers 0 " +DATA_ARGS="${DATA_ARGS} ${REDUNDANT_DATA_ARGS}" +# --eval-interval 1000 - do validation after every 1000 arguments +# --eval-iters 100 - do validation for 100 iterations +# --save-interval 1000 - save the model after every 1000 iterations +# --log-interval 1 - print iteration lossees after every 1 iteration OUTPUT_ARGS=" --log-interval 1 \ --save-interval 1000 \ @@ -158,7 +191,6 @@ SCRIPT="python -u pretrain_gpt.py \ " -export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/ssingh37/csc547/lit-gpt-dev" export OMP_NUM_THREADS=7 run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ./examples/get_rank_from_slurm.sh ${SCRIPT}" diff --git a/megatron/arguments.py b/megatron/arguments.py index 62d1ae3598..eb6aae38f9 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1164,6 +1164,9 @@ def _add_data_args(parser): group.add_argument('--eod-mask-loss', action='store_true', help='Mask loss for the end of document tokens.') + ## add separate argument for lit gpt data paths + group.add_argument('--lit-gpt-data-path', type=str, + help="data path for custom lit gpt dataloaders") return parser diff --git a/pretrain_gpt.py b/pretrain_gpt.py index bea5b9626d..94777fe07b 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -149,6 +149,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): train_iterator, valid_iterator = create_dataloaders( batch_size= args.micro_batch_size, 
block_size= args.seq_length, + train_data_dir = args.lit_gpt_data_path, + val_data_dir = args.lit_gpt_data_path, ) # these flags are set within megatron in From 1b0f9523cc457934e535c0fcefb76417338bd267 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 14:27:52 -0500 Subject: [PATCH 13/25] update README --- examples/README.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index 439a2a32f6..645a2e447b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,6 +1,8 @@ # How to setup on frontier ## Installing all dependencies +# Note that this is a python virtual environment based setup +# You might need to change this a bit for conda ``` cd /lustre/orion/scratch/$(whoami)/csc569/ bash install_everything_on_frontier.sh @@ -8,8 +10,20 @@ bash install_everything_on_frontier.sh This should work, let Siddharth know if it doesn't - ## Training TinyLLaMA +First checkout the tiny-llama branch of megatron-axonn +Then open `examples/run_axonn_amd_tinyllama.sh`, and change the following + +``` +# These are the two things you need to change as per your setup +# 1. Make LD_LIBRARY_PATH point to wherever your plugin is installed +# this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/path/to/aws-ofi-rccl/build/lib" +# 2. Make PYTHONPATH point to your local clone of litgpt +export PYTHONPATH="$PYTHONPATH:/path/to/lit-gpt-dev" +``` + +Now you are ready to train. To launch on 16 nodes (128 GPUs) for 2 hours ``` ## checkout the tiny-llama branch From 2b6ddc4c2cea3b592f8821e60320dc1de43fbc1d Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 14:29:38 -0500 Subject: [PATCH 14/25] update README --- examples/README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/README.md b/examples/README.md index 645a2e447b..c6c33fc5cc 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,8 +1,14 @@ # How to setup on frontier ## Installing all dependencies -# Note that this is a python virtual environment based setup -# You might need to change this a bit for conda +Note that this is a python virtual environment based setup +You might need to change this a bit for conda + +Also this assumes that you are starting from scratch and have no venv/conda +environment enabled. + +We are going to install everything on scratch. + ``` cd /lustre/orion/scratch/$(whoami)/csc569/ bash install_everything_on_frontier.sh From 0582227267c7d3f4c447786f301aa26a92f877f7 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 29 Jan 2024 14:31:14 -0500 Subject: [PATCH 15/25] minor --- examples/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/README.md b/examples/README.md index c6c33fc5cc..073b601ea6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -17,7 +17,7 @@ bash install_everything_on_frontier.sh This should work, let Siddharth know if it doesn't ## Training TinyLLaMA -First checkout the tiny-llama branch of megatron-axonn +First checkout the tiny-llama branch of Megatron-AxoNN. Then open `examples/run_axonn_amd_tinyllama.sh`, and change the following ``` @@ -30,9 +30,11 @@ export PYTHONPATH="$PYTHONPATH:/path/to/lit-gpt-dev" ``` Now you are ready to train. 
-To launch on 16 nodes (128 GPUs) for 2 hours +To launch on 16 nodes (128 GPUs) for 2 hours: ``` ## checkout the tiny-llama branch -sbatch -N 128 -t 02:00:00 examples/run_axonn_amd_tinyllama.sh +sbatch -N 128 -o /path/to/output/file -t 02:00:00 examples/run_axonn_amd_tinyllama.sh ``` + + From bc55cfba80fd20996b8a45ff4cd1e5b1c762de12 Mon Sep 17 00:00:00 2001 From: Neel Jain <90730659+neelsjain@users.noreply.github.com> Date: Fri, 9 Feb 2024 15:38:42 -0500 Subject: [PATCH 16/25] updated hparam (#21) Co-authored-by: Neel Jain --- examples/run_axonn_amd_tinyllama_hparam.sh | 245 +++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100755 examples/run_axonn_amd_tinyllama_hparam.sh diff --git a/examples/run_axonn_amd_tinyllama_hparam.sh b/examples/run_axonn_amd_tinyllama_hparam.sh new file mode 100755 index 0000000000..c47ecf706e --- /dev/null +++ b/examples/run_axonn_amd_tinyllama_hparam.sh @@ -0,0 +1,245 @@ +#!/bin/bash +#SBATCH -p batch +#SBATCH -A CSC569 + +# tiny_llama = [ +# dict( +# name="tiny-llama-1.1b{}", +# hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"), +# block_size=2048, #### CHECKED (Seq Length) +# vocab_size=32000, #### NEED TO ADD VOCAB FILES +# padding_multiple=64, ### NOT RELEVANT +# n_layer=22, ### CHECKED (NUM_LAYERS) +# n_head=32, ### CHECKED (NUM_HEADS) +# n_embd=2048, ### CHECKED (HIDDEN_SIZE) +# rotary_percentage=1.0, ### TODO: CHECK IF -- USING ROTARY (I think the default is 1: https://github.com/search?q=repo%3Aaxonn-ai%2FMegatron-AxoNN+rotary_percent&type=code) +# parallel_residual=False, ### TODO: CHECK (https://github.com/azshue/lit-gpt-dev/blob/99fb9363646bfacb686f72f58274392e6036ad6c/lit_gpt/model.py#L157 and apply_residual_connection_post_layernorm are the same) +# bias=False, ### TODO: CHECK "disable-bias-linear" I think. This is the bias for the linear layers +# _norm_class="RMSNorm", ### CHECKED "--normalization RMSNorm" +# norm_eps=1e-5, ### TODO: UNLCLEAR WHERE THIS IS -- I think this is fine (https://github.com/search?q=repo%3Aaxonn-ai%2FMegatron-AxoNN%20norm_eps&type=code) +# _mlp_class="LLaMAMLP", ### CHECKED "From Line 112, # --swiglu makes ParallelMLP equivalent to LLAMAMLP" +# intermediate_size=5632, ### CHECKED "FFN_HIDDEN_SIZE" +# n_query_groups=4, #### CHECKED: NUM_QUERY_GROUPS +# ) +# ] +### WE want global batch size of 4M so 4000000/2048 +#### We are gonna copy Olma's BS of 4M +# global_batch_size = 2048 #NEEL: UPDATED IN BASH SCRIPT +# learning_rate = 4e-4 #NEEL: Checked "--lr 4.0e-4" +#### THIS COULD BE SET ACCORDING TO HOW MANY GPUs we want to use +# micro_batch_size = 8 +# max_tokens = int(1e12) #NEEL: UPDATED IN BASH SCRIPT +# warmup_steps = 2000 # We are gonna use tinyllama warmup steps +#### BELOW ARE IRRELVANT #### +# log_step_interval = 1 +# eval_iters = 100 +# save_step_interval = 1000 +# eval_step_interval = 1000 +#### ABOVE ARE IRRELVANT #### + +# weight_decay = 1e-1 ### Neel: CHECKED "weight-decay 1e-1" +# beta1 = 0.9 ### Neel: CHECKED +# beta2 = 0.95 ### Neel: CHECKED +# grad_clip = 1.0 ### Neel: CHECKED +# decay_lr = True <--- This is irrevalant +# min_lr = 4e-5 ### Neel: CHECKED + +userid=$(whoami) +# These are the two things you need to change as per your setup +# 1. Make LD_LIBRARY_PATH point to wherever your plugin is installed +# this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/$userid/aws-ofi-rccl/build/lib" +# 2. 
Make PYTHONPATH point to your local clone of litgpt +export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/$userid/csc547/lit-gpt-dev" + +# The rest of the script should work as it is + +echo "This TinyLLAMA script will work for <=512 GPUs." + +module load PrgEnv-cray +module load cray-python/3.9.13.1 +. /ccs/home/$userid/axonn_venv/bin/activate +module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses +module load libfabric + +export MPICH_GPU_SUPPORT_ENABLED=0 + +## some RCCL env variables +export FI_CXI_ATS=0 +export HSA_FORCE_FINE_GRAIN_PCIE=1 +export NCCL_CROSS_NIC=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +## calculating the number of nodes and GPUs +NNODES=$SLURM_JOB_NUM_NODES +GPUS_PER_NODE=8 ## change as per your machine +GPUS=$(( NNODES * GPUS_PER_NODE )) + +# setting variables for torch.distributed +export MASTER_ADDR=$(hostname) +export MASTER_PORT=29500 +export WORLD_SIZE=${GPUS} + +# train_data_dir and val_data_dir are set to this as of now +DATADIR="/lustre/orion/csc569/proj-shared/language_datasets/" +DATASET="spj_star_combined_full_tinyllama_tokd" +DATAPATH="$DATADIR/$DATASET" + + +# these are redundant for tiny-llams, so ignore +########## TODO: FIX TO TINY LLAMA VOCAB ############# +MEGATRON_TOKENIZER_DIR="/lustre/orion/proj-shared/csc569/book_corpus_megatron" +VOCAB_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-vocab.json" +MERGE_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-merges.txt" + + +# we will save and load model checkpoints here +# if these are non-empty training will restart from the latest checkpoint here +# else training will start from scratch +CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" + + +# tiny-llama1.1B architecture shapes +# https://github.com/azshue/lit-gpt-dev/blob/tiny-llama/lit_gpt/config.py +NUM_LAYERS=22 +NUM_HEADS=32 +HIDDEN_SIZE=2048 +FFN_HIDDEN_SIZE=5632 +NUM_QUERY_GROUPS=4 + +# batch size, seq length, and iterations +GLOBAL_BATCH_SIZE=2048 ## Neel: 2048x2048 = 4M per batch +SEQUENCE_LENGTH=2048 +TOKENS_IN_BILLIONS=1000 ### Neel: Changed 1T ##### +TRAIN_ITERS=$(( TOKENS_IN_BILLIONS * 1000000000 / GLOBAL_BATCH_SIZE / SEQUENCE_LENGTH + 100 )) +echo "Number of training iterations : ${TRAIN_ITERS}" + +## AxoNN parallelism args +## These do not affect the science +ROW_TENSOR_PARR=1 +COLUMN_TENSOR_PARR=1 +DEPTH_TENSOR_PARR=2 +PIPE_PARR=1 +CACHE_LAYERS=22 +OVERLAP=True + + +## DERIVED ARGUMENTS (ignore) +MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) +DP=$(( GPUS / MP )) +MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) + +# The following args enable LLaMA +# --swiglu makes ParallelMLP equivalent to LLAMAMLP +# --group-query-attention - enables group query attention +# --num-query-groups - number of query groups for group query attention +# --normalization RMSNorm - switch from layernorm to RMSNorm (someone confirm?) 
+# --use-rotary-position-embeddings - use RoPE embeddings instead of learned position embeddings +# +# The following args disable features not compatible with AMD +# --no-gradient-accumulation-fusion +# --use-amd + + + + + +GPT_ARGS=" + --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ + --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ + --depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \ + --pipeline-model-parallel-size ${PIPE_PARR} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_HEADS} \ + --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ + --seq-length ${SEQUENCE_LENGTH} \ + --max-position-embeddings ${SEQUENCE_LENGTH} \ + --micro-batch-size ${MICRO_BATCH_SIZE} \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --lr 4.0e-4 \ + --train-iters ${TRAIN_ITERS} \ + --lr-decay-iters ${TRAIN_ITERS} \ + --lr-decay-style cosine \ + --min-lr 4.0e-5 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 2000 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --bf16 \ + --no-gradient-accumulation-fusion \ + --use-amd \ + --recompute-granularity full \ + --recompute-method uniform \ + --recompute-num-layers 1 \ + --use-flash-attn \ + --swiglu \ + --use-rotary-position-embeddings \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups ${NUM_QUERY_GROUPS} +" + +## AxoNN specific args for communication optimizations +# these do not affect the ML science +if [[ $OVERLAP == "True" ]] +then + GPT_ARGS="${GPT_ARGS} \ + --overlap-axonn-comm \ + --overlap-axonn-reduce-scatter \ + --overlap-axonn-all-gather\ + --num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}" +fi + +# --lit-gpt-data-path - is pointing to your dataset +# currently both train and val splits are taken fron --data-path +# the --custom-dataloader argument bypasses megatron's dataloaders +# --num-workers 0 - disables multiprocesses dataloading +# which can hang jobs at scale + +DATA_ARGS=" + --lit-gpt-data-path $DATAPATH \ + --custom-dataloader \ + --num-workers 0 +" + +# these args are for megatron dataloaders +# these are not needed for litgpt, but not passing them +# might give you errors +# THESE DO NOTHING +REDUNDANT_DATA_ARGS=" + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 \ +" + +DATA_ARGS="${DATA_ARGS} ${REDUNDANT_DATA_ARGS}" + +# --eval-interval 1000 - do validation after every 1000 arguments +# --eval-iters 100 - do validation for 100 iterations +# --save-interval 1000 - save the model after every 1000 iterations +# --log-interval 1 - print iteration lossees after every 1 iteration +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 1000 \ + --eval-interval 1000 \ + --eval-iters 100 \ +" + +SCRIPT="python -u pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --distributed-backend nccl \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH +" + + +export OMP_NUM_THREADS=7 +run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ./examples/get_rank_from_slurm.sh ${SCRIPT}" + +echo ${run_cmd} +eval ${run_cmd} +set +x \ No newline at end of file From 76efc9a54ff219c251d7ce8f4c9a7bf67624b1a8 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 9 Feb 2024 16:31:23 -0500 Subject: [PATCH 17/25] sanity check on rewinding dataloaders --- pretrain_gpt.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 94777fe07b..b7ea61d8e0 100644 --- a/pretrain_gpt.py +++ 
b/pretrain_gpt.py @@ -2,6 +2,7 @@ """Pretrain GPT""" import os +import time import torch from functools import partial from megatron import get_args @@ -118,12 +119,6 @@ def forward_step(data_iterator, model): labels = drop(labels, skip_channels=True) loss_mask = drop(loss_mask, skip_channels=True) position_ids = drop(position_ids, skip_channels=True) - #print(tokens.shape) - #print(labels.shape) - #print(loss_mask.shape) - #print(attention_mask.shape) - #print(position_ids.shape) - #exit() if args.overlap_axonn_comm: ctx = partial(optimize_communication, @@ -141,6 +136,8 @@ def forward_step(data_iterator, model): return output_tensor, partial(loss_func, loss_mask) + + def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() @@ -152,12 +149,25 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): train_data_dir = args.lit_gpt_data_path, val_data_dir = args.lit_gpt_data_path, ) - # these flags are set within megatron in # the OG dataloader args.do_train = True args.do_valid = True args.do_test = False + if args.consumed_train_samples > 0 and train_iterator is not None: + print_rank_0(f"Rewinding dataloader to {args.consumed_train_samples} samples") + train_iterator_consumed_samples = 0 + fake_iters = 0 + start = time.time() + while train_iterator_consumed_samples < args.consumed_train_samples: + next(train_iterator) + train_iterator_consumed_samples += args.global_batch_size + fake_iters += 1 + if fake_iters % args.eval_interval == 0: + for _ in range(args.eval_iters): + next(valid_iterator) + end = time.time() + print_rank_0(f"Time for rewinding the dataloader on rank 0 = {end-start:.2f} s") return train_iterator, valid_iterator else: From 88e859fb9d1d3257ee54886c665591d982ae0a17 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 9 Feb 2024 16:32:30 -0500 Subject: [PATCH 18/25] move venv to burst buffer --- examples/run_axonn_amd_tinyllama.sh | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 9f5db3d3f8..749145e35f 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -2,6 +2,10 @@ #SBATCH -p batch #SBATCH -A CSC569 +## calculating the number of nodes and GPUs +NNODES=$SLURM_JOB_NUM_NODES +GPUS_PER_NODE=8 ## change as per your machine +GPUS=$(( NNODES * GPUS_PER_NODE )) userid=$(whoami) # These are the two things you need to change as per your setup @@ -15,6 +19,14 @@ export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/$userid/csc547/lit-gpt-dev" echo "This TinyLLAMA script will work for <=512 GPUs." +echo "moving environment to burst buffer" +## load venv onto burst buffer +srun -N $NNODES --ntasks-per-node=1 prepare_venv.sh +## delete old symbolic link +rm -rf ~/axonn_venv +## craete new symbolic link +ln -s /mnt/bb/ssingh37/axonn_venv ~/axonn_venv + module load PrgEnv-cray module load cray-python/3.9.13.1 . 
/ccs/home/$userid/axonn_venv/bin/activate @@ -29,10 +41,6 @@ export HSA_FORCE_FINE_GRAIN_PCIE=1 export NCCL_CROSS_NIC=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 -## calculating the number of nodes and GPUs -NNODES=$SLURM_JOB_NUM_NODES -GPUS_PER_NODE=8 ## change as per your machine -GPUS=$(( NNODES * GPUS_PER_NODE )) # setting variables for torch.distributed export MASTER_ADDR=$(hostname) From 2c58c30905ad2fa5691e5f6862a34c56709330d6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 9 Feb 2024 17:17:12 -0500 Subject: [PATCH 19/25] switch to llama2 tokenizer --- examples/run_axonn_amd_tinyllama.sh | 78 +++++-- examples/run_axonn_amd_tinyllama_hparam.sh | 245 --------------------- prepare_venv.sh | 9 + 3 files changed, 66 insertions(+), 266 deletions(-) delete mode 100755 examples/run_axonn_amd_tinyllama_hparam.sh create mode 100755 prepare_venv.sh diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 749145e35f..a8552082c3 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -1,12 +1,56 @@ #!/bin/bash #SBATCH -p batch #SBATCH -A CSC569 +#SBATCH -C nvme + +# tiny_llama = [ +# dict( +# name="tiny-llama-1.1b{}", +# hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"), +# block_size=2048, #### CHECKED (Seq Length) +# vocab_size=32000, #### NEED TO ADD VOCAB FILES +# padding_multiple=64, ### NOT RELEVANT +# n_layer=22, ### CHECKED (NUM_LAYERS) +# n_head=32, ### CHECKED (NUM_HEADS) +# n_embd=2048, ### CHECKED (HIDDEN_SIZE) +# rotary_percentage=1.0, ### TODO: CHECK IF -- USING ROTARY (I think the default is 1: https://github.com/search?q=repo%3Aaxonn-ai%2FMegatron-AxoNN+rotary_percent&type=code) +# parallel_residual=False, ### TODO: CHECK (https://github.com/azshue/lit-gpt-dev/blob/99fb9363646bfacb686f72f58274392e6036ad6c/lit_gpt/model.py#L157 and apply_residual_connection_post_layernorm are the same) +# bias=False, ### TODO: CHECK "disable-bias-linear" I think. This is the bias for the linear layers +# _norm_class="RMSNorm", ### CHECKED "--normalization RMSNorm" +# norm_eps=1e-5, ### TODO: UNLCLEAR WHERE THIS IS -- I think this is fine (https://github.com/search?q=repo%3Aaxonn-ai%2FMegatron-AxoNN%20norm_eps&type=code) +# _mlp_class="LLaMAMLP", ### CHECKED "From Line 112, # --swiglu makes ParallelMLP equivalent to LLAMAMLP" +# intermediate_size=5632, ### CHECKED "FFN_HIDDEN_SIZE" +# n_query_groups=4, #### CHECKED: NUM_QUERY_GROUPS +# ) +# ] +### WE want global batch size of 4M so 4000000/2048 +#### We are gonna copy Olma's BS of 4M +# global_batch_size = 2048 #NEEL: UPDATED IN BASH SCRIPT +# learning_rate = 4e-4 #NEEL: Checked "--lr 4.0e-4" +#### THIS COULD BE SET ACCORDING TO HOW MANY GPUs we want to use +# micro_batch_size = 8 +# max_tokens = int(1e12) #NEEL: UPDATED IN BASH SCRIPT +# warmup_steps = 2000 # We are gonna use tinyllama warmup steps +#### BELOW ARE IRRELVANT #### +# log_step_interval = 1 +# eval_iters = 100 +# save_step_interval = 1000 +# eval_step_interval = 1000 +#### ABOVE ARE IRRELVANT #### + +# weight_decay = 1e-1 ### Neel: CHECKED "weight-decay 1e-1" +# beta1 = 0.9 ### Neel: CHECKED +# beta2 = 0.95 ### Neel: CHECKED +# grad_clip = 1.0 ### Neel: CHECKED +# decay_lr = True <--- This is irrevalant +# min_lr = 4e-5 ### Neel: CHECKED ## calculating the number of nodes and GPUs NNODES=$SLURM_JOB_NUM_NODES GPUS_PER_NODE=8 ## change as per your machine GPUS=$(( NNODES * GPUS_PER_NODE )) + userid=$(whoami) # These are the two things you need to change as per your setup # 1. 
Make LD_LIBRARY_PATH point to wherever your plugin is installed @@ -16,9 +60,10 @@ export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/$userid/aws-ofi-rccl/build/ export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/$userid/csc547/lit-gpt-dev" # The rest of the script should work as it is - echo "This TinyLLAMA script will work for <=512 GPUs." + +# This blob is setting up my python venv, ignore for conda builds echo "moving environment to burst buffer" ## load venv onto burst buffer srun -N $NNODES --ntasks-per-node=1 prepare_venv.sh @@ -26,10 +71,11 @@ srun -N $NNODES --ntasks-per-node=1 prepare_venv.sh rm -rf ~/axonn_venv ## craete new symbolic link ln -s /mnt/bb/ssingh37/axonn_venv ~/axonn_venv - module load PrgEnv-cray module load cray-python/3.9.13.1 . /ccs/home/$userid/axonn_venv/bin/activate + + module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses module load libfabric @@ -53,10 +99,9 @@ DATASET="spj_star_combined_full_tinyllama_tokd" DATAPATH="$DATADIR/$DATASET" -# these are redundant for tiny-llams, so ignore -MEGATRON_TOKENIZER_DIR="/lustre/orion/proj-shared/csc569/book_corpus_megatron" -VOCAB_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-vocab.json" -MERGE_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-merges.txt" +########## TODO: FIX TO TINY LLAMA VOCAB ############# +TOKENIZER_DIR="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/llama-tokenizer" +TOKENIZER_MODEL="${TOKENIZER_DIR}/tokenizer.model" # we will save and load model checkpoints here @@ -74,9 +119,9 @@ FFN_HIDDEN_SIZE=5632 NUM_QUERY_GROUPS=4 # batch size, seq length, and iterations -GLOBAL_BATCH_SIZE=512 +GLOBAL_BATCH_SIZE=32 ## Neel: 2048x2048 = 4M per batch SEQUENCE_LENGTH=2048 -TOKENS_IN_BILLIONS=3000 +TOKENS_IN_BILLIONS=1000 ### Neel: Changed 1T ##### TRAIN_ITERS=$(( TOKENS_IN_BILLIONS * 1000000000 / GLOBAL_BATCH_SIZE / SEQUENCE_LENGTH + 100 )) echo "Number of training iterations : ${TRAIN_ITERS}" @@ -106,6 +151,7 @@ MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) # --no-gradient-accumulation-fusion # --use-amd + GPT_ARGS=" --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ @@ -163,21 +209,11 @@ fi DATA_ARGS=" --lit-gpt-data-path $DATAPATH \ --custom-dataloader \ - --num-workers 0 + --num-workers 0 \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} " -# these args are for megatron dataloaders -# these are not needed for litgpt, but not passing them -# might give you errors -# THESE DO NOTHING -REDUNDANT_DATA_ARGS=" - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --split 949,50,1 \ -" - -DATA_ARGS="${DATA_ARGS} ${REDUNDANT_DATA_ARGS}" - # --eval-interval 1000 - do validation after every 1000 arguments # --eval-iters 100 - do validation for 100 iterations # --save-interval 1000 - save the model after every 1000 iterations diff --git a/examples/run_axonn_amd_tinyllama_hparam.sh b/examples/run_axonn_amd_tinyllama_hparam.sh deleted file mode 100755 index c47ecf706e..0000000000 --- a/examples/run_axonn_amd_tinyllama_hparam.sh +++ /dev/null @@ -1,245 +0,0 @@ -#!/bin/bash -#SBATCH -p batch -#SBATCH -A CSC569 - -# tiny_llama = [ -# dict( -# name="tiny-llama-1.1b{}", -# hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"), -# block_size=2048, #### CHECKED (Seq Length) -# vocab_size=32000, #### NEED TO ADD VOCAB FILES -# padding_multiple=64, ### NOT RELEVANT -# n_layer=22, ### CHECKED (NUM_LAYERS) -# n_head=32, ### CHECKED (NUM_HEADS) -# n_embd=2048, ### CHECKED (HIDDEN_SIZE) -# 
rotary_percentage=1.0, ### TODO: CHECK IF -- USING ROTARY (I think the default is 1: https://github.com/search?q=repo%3Aaxonn-ai%2FMegatron-AxoNN+rotary_percent&type=code) -# parallel_residual=False, ### TODO: CHECK (https://github.com/azshue/lit-gpt-dev/blob/99fb9363646bfacb686f72f58274392e6036ad6c/lit_gpt/model.py#L157 and apply_residual_connection_post_layernorm are the same) -# bias=False, ### TODO: CHECK "disable-bias-linear" I think. This is the bias for the linear layers -# _norm_class="RMSNorm", ### CHECKED "--normalization RMSNorm" -# norm_eps=1e-5, ### TODO: UNLCLEAR WHERE THIS IS -- I think this is fine (https://github.com/search?q=repo%3Aaxonn-ai%2FMegatron-AxoNN%20norm_eps&type=code) -# _mlp_class="LLaMAMLP", ### CHECKED "From Line 112, # --swiglu makes ParallelMLP equivalent to LLAMAMLP" -# intermediate_size=5632, ### CHECKED "FFN_HIDDEN_SIZE" -# n_query_groups=4, #### CHECKED: NUM_QUERY_GROUPS -# ) -# ] -### WE want global batch size of 4M so 4000000/2048 -#### We are gonna copy Olma's BS of 4M -# global_batch_size = 2048 #NEEL: UPDATED IN BASH SCRIPT -# learning_rate = 4e-4 #NEEL: Checked "--lr 4.0e-4" -#### THIS COULD BE SET ACCORDING TO HOW MANY GPUs we want to use -# micro_batch_size = 8 -# max_tokens = int(1e12) #NEEL: UPDATED IN BASH SCRIPT -# warmup_steps = 2000 # We are gonna use tinyllama warmup steps -#### BELOW ARE IRRELVANT #### -# log_step_interval = 1 -# eval_iters = 100 -# save_step_interval = 1000 -# eval_step_interval = 1000 -#### ABOVE ARE IRRELVANT #### - -# weight_decay = 1e-1 ### Neel: CHECKED "weight-decay 1e-1" -# beta1 = 0.9 ### Neel: CHECKED -# beta2 = 0.95 ### Neel: CHECKED -# grad_clip = 1.0 ### Neel: CHECKED -# decay_lr = True <--- This is irrevalant -# min_lr = 4e-5 ### Neel: CHECKED - -userid=$(whoami) -# These are the two things you need to change as per your setup -# 1. Make LD_LIBRARY_PATH point to wherever your plugin is installed -# this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw) -export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/$userid/aws-ofi-rccl/build/lib" -# 2. Make PYTHONPATH point to your local clone of litgpt -export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/$userid/csc547/lit-gpt-dev" - -# The rest of the script should work as it is - -echo "This TinyLLAMA script will work for <=512 GPUs." - -module load PrgEnv-cray -module load cray-python/3.9.13.1 -. 
/ccs/home/$userid/axonn_venv/bin/activate -module load amd-mixed/5.6.0 #this should match with the rocm version your pytorch uses -module load libfabric - -export MPICH_GPU_SUPPORT_ENABLED=0 - -## some RCCL env variables -export FI_CXI_ATS=0 -export HSA_FORCE_FINE_GRAIN_PCIE=1 -export NCCL_CROSS_NIC=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -## calculating the number of nodes and GPUs -NNODES=$SLURM_JOB_NUM_NODES -GPUS_PER_NODE=8 ## change as per your machine -GPUS=$(( NNODES * GPUS_PER_NODE )) - -# setting variables for torch.distributed -export MASTER_ADDR=$(hostname) -export MASTER_PORT=29500 -export WORLD_SIZE=${GPUS} - -# train_data_dir and val_data_dir are set to this as of now -DATADIR="/lustre/orion/csc569/proj-shared/language_datasets/" -DATASET="spj_star_combined_full_tinyllama_tokd" -DATAPATH="$DATADIR/$DATASET" - - -# these are redundant for tiny-llams, so ignore -########## TODO: FIX TO TINY LLAMA VOCAB ############# -MEGATRON_TOKENIZER_DIR="/lustre/orion/proj-shared/csc569/book_corpus_megatron" -VOCAB_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-vocab.json" -MERGE_FILE="${MEGATRON_TOKENIZER_DIR}/gpt2-merges.txt" - - -# we will save and load model checkpoints here -# if these are non-empty training will restart from the latest checkpoint here -# else training will start from scratch -CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" - - -# tiny-llama1.1B architecture shapes -# https://github.com/azshue/lit-gpt-dev/blob/tiny-llama/lit_gpt/config.py -NUM_LAYERS=22 -NUM_HEADS=32 -HIDDEN_SIZE=2048 -FFN_HIDDEN_SIZE=5632 -NUM_QUERY_GROUPS=4 - -# batch size, seq length, and iterations -GLOBAL_BATCH_SIZE=2048 ## Neel: 2048x2048 = 4M per batch -SEQUENCE_LENGTH=2048 -TOKENS_IN_BILLIONS=1000 ### Neel: Changed 1T ##### -TRAIN_ITERS=$(( TOKENS_IN_BILLIONS * 1000000000 / GLOBAL_BATCH_SIZE / SEQUENCE_LENGTH + 100 )) -echo "Number of training iterations : ${TRAIN_ITERS}" - -## AxoNN parallelism args -## These do not affect the science -ROW_TENSOR_PARR=1 -COLUMN_TENSOR_PARR=1 -DEPTH_TENSOR_PARR=2 -PIPE_PARR=1 -CACHE_LAYERS=22 -OVERLAP=True - - -## DERIVED ARGUMENTS (ignore) -MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) -DP=$(( GPUS / MP )) -MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) - -# The following args enable LLaMA -# --swiglu makes ParallelMLP equivalent to LLAMAMLP -# --group-query-attention - enables group query attention -# --num-query-groups - number of query groups for group query attention -# --normalization RMSNorm - switch from layernorm to RMSNorm (someone confirm?) 
-# --use-rotary-position-embeddings - use RoPE embeddings instead of learned position embeddings -# -# The following args disable features not compatible with AMD -# --no-gradient-accumulation-fusion -# --use-amd - - - - - -GPT_ARGS=" - --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ - --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ - --depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \ - --pipeline-model-parallel-size ${PIPE_PARR} \ - --num-layers ${NUM_LAYERS} \ - --hidden-size ${HIDDEN_SIZE} \ - --num-attention-heads ${NUM_HEADS} \ - --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ - --seq-length ${SEQUENCE_LENGTH} \ - --max-position-embeddings ${SEQUENCE_LENGTH} \ - --micro-batch-size ${MICRO_BATCH_SIZE} \ - --global-batch-size ${GLOBAL_BATCH_SIZE} \ - --lr 4.0e-4 \ - --train-iters ${TRAIN_ITERS} \ - --lr-decay-iters ${TRAIN_ITERS} \ - --lr-decay-style cosine \ - --min-lr 4.0e-5 \ - --weight-decay 1e-1 \ - --lr-warmup-iters 2000 \ - --clip-grad 1.0 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --bf16 \ - --no-gradient-accumulation-fusion \ - --use-amd \ - --recompute-granularity full \ - --recompute-method uniform \ - --recompute-num-layers 1 \ - --use-flash-attn \ - --swiglu \ - --use-rotary-position-embeddings \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups ${NUM_QUERY_GROUPS} -" - -## AxoNN specific args for communication optimizations -# these do not affect the ML science -if [[ $OVERLAP == "True" ]] -then - GPT_ARGS="${GPT_ARGS} \ - --overlap-axonn-comm \ - --overlap-axonn-reduce-scatter \ - --overlap-axonn-all-gather\ - --num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}" -fi - -# --lit-gpt-data-path - is pointing to your dataset -# currently both train and val splits are taken fron --data-path -# the --custom-dataloader argument bypasses megatron's dataloaders -# --num-workers 0 - disables multiprocesses dataloading -# which can hang jobs at scale - -DATA_ARGS=" - --lit-gpt-data-path $DATAPATH \ - --custom-dataloader \ - --num-workers 0 -" - -# these args are for megatron dataloaders -# these are not needed for litgpt, but not passing them -# might give you errors -# THESE DO NOTHING -REDUNDANT_DATA_ARGS=" - --vocab-file $VOCAB_FILE \ - --merge-file $MERGE_FILE \ - --split 949,50,1 \ -" - -DATA_ARGS="${DATA_ARGS} ${REDUNDANT_DATA_ARGS}" - -# --eval-interval 1000 - do validation after every 1000 arguments -# --eval-iters 100 - do validation for 100 iterations -# --save-interval 1000 - save the model after every 1000 iterations -# --log-interval 1 - print iteration lossees after every 1 iteration -OUTPUT_ARGS=" - --log-interval 1 \ - --save-interval 1000 \ - --eval-interval 1000 \ - --eval-iters 100 \ -" - -SCRIPT="python -u pretrain_gpt.py \ - $GPT_ARGS \ - $DATA_ARGS \ - $OUTPUT_ARGS \ - --distributed-backend nccl \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH -" - - -export OMP_NUM_THREADS=7 -run_cmd="srun -N ${NNODES} -n ${GPUS} -c7 --gpus-per-task=1 --gpu-bind=closest ./examples/get_rank_from_slurm.sh ${SCRIPT}" - -echo ${run_cmd} -eval ${run_cmd} -set +x \ No newline at end of file diff --git a/prepare_venv.sh b/prepare_venv.sh new file mode 100755 index 0000000000..80b2aaf17d --- /dev/null +++ b/prepare_venv.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +userid=$(whoami) + +if [ ! 
-d /mnt/bb/${userid}/axonn_venv ]; then + cp /lustre/orion/scratch/${userid}/csc547/axonn_venv.tar.gz /mnt/bb/${userid}/ + cd /mnt/bb/${userid}/ + tar -xf axonn_venv.tar.gz +fi From 36f4bdb6ee7101ad6d72b05670ee0e6b8cdbc7f1 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 9 Feb 2024 18:19:58 -0500 Subject: [PATCH 20/25] disable bias and untie weights --- examples/run_axonn_amd_tinyllama.sh | 26 +++++++----------- megatron/core/tensor_parallel/layers.py | 11 ++++++-- megatron/model/gpt_model.py | 1 - megatron/model/transformer.py | 36 +++++++++++++++++++------ 4 files changed, 47 insertions(+), 27 deletions(-) diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index a8552082c3..4c54573c7e 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -50,7 +50,6 @@ NNODES=$SLURM_JOB_NUM_NODES GPUS_PER_NODE=8 ## change as per your machine GPUS=$(( NNODES * GPUS_PER_NODE )) - userid=$(whoami) # These are the two things you need to change as per your setup # 1. Make LD_LIBRARY_PATH point to wherever your plugin is installed @@ -59,9 +58,6 @@ export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/ccs/home/$userid/aws-ofi-rccl/build/ # 2. Make PYTHONPATH point to your local clone of litgpt export PYTHONPATH="$PYTHONPATH:/lustre/orion/scratch/$userid/csc547/lit-gpt-dev" -# The rest of the script should work as it is -echo "This TinyLLAMA script will work for <=512 GPUs." - # This blob is setting up my python venv, ignore for conda builds echo "moving environment to burst buffer" @@ -69,7 +65,7 @@ echo "moving environment to burst buffer" srun -N $NNODES --ntasks-per-node=1 prepare_venv.sh ## delete old symbolic link rm -rf ~/axonn_venv -## craete new symbolic link +## create new symbolic link ln -s /mnt/bb/ssingh37/axonn_venv ~/axonn_venv module load PrgEnv-cray module load cray-python/3.9.13.1 @@ -98,18 +94,14 @@ DATADIR="/lustre/orion/csc569/proj-shared/language_datasets/" DATASET="spj_star_combined_full_tinyllama_tokd" DATAPATH="$DATADIR/$DATASET" - -########## TODO: FIX TO TINY LLAMA VOCAB ############# TOKENIZER_DIR="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/llama-tokenizer" TOKENIZER_MODEL="${TOKENIZER_DIR}/tokenizer.model" - # we will save and load model checkpoints here # if these are non-empty training will restart from the latest checkpoint here # else training will start from scratch CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" - # tiny-llama1.1B architecture shapes # https://github.com/azshue/lit-gpt-dev/blob/tiny-llama/lit_gpt/config.py NUM_LAYERS=22 @@ -119,7 +111,7 @@ FFN_HIDDEN_SIZE=5632 NUM_QUERY_GROUPS=4 # batch size, seq length, and iterations -GLOBAL_BATCH_SIZE=32 ## Neel: 2048x2048 = 4M per batch +GLOBAL_BATCH_SIZE=2048 ## Neel: 2048x2048 = 4M per batch SEQUENCE_LENGTH=2048 TOKENS_IN_BILLIONS=1000 ### Neel: Changed 1T ##### TRAIN_ITERS=$(( TOKENS_IN_BILLIONS * 1000000000 / GLOBAL_BATCH_SIZE / SEQUENCE_LENGTH + 100 )) @@ -129,12 +121,11 @@ echo "Number of training iterations : ${TRAIN_ITERS}" ## These do not affect the science ROW_TENSOR_PARR=1 COLUMN_TENSOR_PARR=1 -DEPTH_TENSOR_PARR=2 +DEPTH_TENSOR_PARR=1 PIPE_PARR=1 -CACHE_LAYERS=22 +CACHE_LAYERS=0 OVERLAP=True - ## DERIVED ARGUMENTS (ignore) MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) DP=$(( GPUS / MP )) @@ -146,12 +137,13 @@ MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) # --num-query-groups - number of query groups for group query attention # --normalization RMSNorm 
- switch from layernorm to RMSNorm (someone confirm?) # --use-rotary-position-embeddings - use RoPE embeddings instead of learned position embeddings -# +# --untie-embeddings-and-output-weights - untie embedding and last layer weights +# --disable-bias-linear - disables bias in all nn.linear layers + # The following args disable features not compatible with AMD # --no-gradient-accumulation-fusion # --use-amd - GPT_ARGS=" --row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \ --column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \ @@ -186,7 +178,9 @@ GPT_ARGS=" --use-rotary-position-embeddings \ --normalization RMSNorm \ --group-query-attention \ - --num-query-groups ${NUM_QUERY_GROUPS} + --num-query-groups ${NUM_QUERY_GROUPS} \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear " ## AxoNN specific args for communication optimizations diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index b09879bd83..052b014f21 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -582,6 +582,8 @@ def __init__( keep_master_weight_for_test=False, skip_bias_add=False, skip_weight_param_allocation: bool = False, + for_embedding_and_clf_layer: bool = False + ): super(ColumnParallelLinear, self).__init__() @@ -590,10 +592,11 @@ def __init__( self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() + world_size = get_tensor_model_parallel_world_size(for_embedding_and_clf_layer=for_embedding_and_clf_layer) self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add self.config = config + self.for_embedding_and_clf_layer = for_embedding_and_clf_layer # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result @@ -601,6 +604,7 @@ def __init__( # Initialize weight. if not skip_weight_param_allocation: if config.use_cpu_initialization: + raise NotImplementedError self.weight = Parameter( torch.empty( self.output_size_per_partition, self.input_size, dtype=config.params_dtype @@ -628,7 +632,8 @@ def __init__( ) if config.perform_initialization: _initialize_affine_weight_gpu( - self.weight, init_method, partition_dim=0, stride=stride + self.weight, init_method, partition_dim=0, stride=stride, + for_embedding_and_clf_layer=self.for_embedding_and_clf_layer ) else: self.weight = None @@ -724,6 +729,7 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): input_parallel = copy_to_tensor_model_parallel_region(input_) # Matrix multiply. if not weight.requires_grad: + raise NotImplementedError self._forward_impl = linear_with_frozen_weight else: self._forward_impl = linear_with_grad_accumulation_and_async_allreduce @@ -734,6 +740,7 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): gradient_accumulation_fusion=self.gradient_accumulation_fusion, async_grad_allreduce=self.async_tensor_model_parallel_allreduce, sequence_parallel=self.sequence_parallel, + for_embedding_and_clf_layer=self.for_embedding_and_clf_layer ) if self.gather_output: # All-gather across the partitions. 
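For context on the two flags introduced by this patch: --disable-bias-linear removes the bias term from the linear layers (TinyLlama sets bias=False), and --untie-embeddings-and-output-weights gives the final vocabulary projection its own weight matrix instead of sharing the input embedding. The toy PyTorch sketch below is only an illustration of that distinction, not Megatron-AxoNN code; the default sizes are kept small, with TinyLlama's real values (vocab 32000, hidden 2048) noted in the comments.

```
# Toy illustration of tied vs. untied embeddings and output weights.
# Not Megatron-AxoNN code; sizes are small for speed (TinyLlama-1.1B uses
# vocab_size=32000, hidden_size=2048).
import torch
import torch.nn as nn

class ToyLM(nn.Module):
    def __init__(self, vocab_size=1000, hidden_size=64, untie=True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # bias=False loosely mirrors --disable-bias-linear
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if not untie:
            # tied: the output projection reuses the embedding matrix
            self.lm_head.weight = self.embed.weight

    def forward(self, token_ids):
        hidden = self.embed(token_ids)   # stand-in for the transformer stack
        return self.lm_head(hidden)      # logits over the vocabulary

tied = ToyLM(untie=False)
untied = ToyLM(untie=True)
print(tied.lm_head.weight is tied.embed.weight)      # True  -> one shared matrix
print(untied.lm_head.weight is untied.embed.weight)  # False -> separate parameters
```

At TinyLlama scale, untying adds one extra 32000 x 2048 matrix (roughly 65M parameters) in exchange for an output head that can learn independently of the input embedding.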
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index bb1af9b9e7..20d7bfde4c 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -58,7 +58,6 @@ def __init__(self, self.post_process = post_process self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights - assert not self.untie_embeddings_and_output_weights, "Megatron-AxoNN doesn't support untied embedding yet" self.language_model, self._language_model_key = get_language_model( config=config, diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 632507bd1b..2fc33ba12c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -95,6 +95,7 @@ def __init__(self, config, layer_number): out_features=ffn_hidden_size, skip_bias_add=True, init_method=config.init_method, + bias = self.add_bias ) self.bias_gelu_fusion = False @@ -125,13 +126,19 @@ def squared_relu(x): skip_bias_add=True, init_method=config.output_layer_init_method, transpose=True, + bias = self.add_bias ) def forward(self, hidden_states): torch.cuda.nvtx.range_push(f"MLP Block") # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states, scatter_input=False, gather_output=False, - cache_weights_in_all_gather = self.cache_weights_in_all_gather) + output = self.dense_h_to_4h(hidden_states, scatter_input=False, gather_output=False, + cache_weights_in_all_gather = self.cache_weights_in_all_gather) + + if isinstance(output, tuple): + intermediate_parallel, bias_parallel = output + else: + intermediate_parallel, bias_parallel = output, None if self.bias_gelu_fusion: assert self.add_bias is True @@ -143,8 +150,14 @@ def forward(self, hidden_states): intermediate_parallel = self.activation_func(intermediate_parallel) # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel, scatter_input=False, gather_output=False, + output = self.dense_4h_to_h(intermediate_parallel, scatter_input=False, gather_output=False, cache_weights_in_all_gather = self.cache_weights_in_all_gather) + + if isinstance(output, tuple): + output, output_bias = output + else: + output, output_bias = output, None + torch.cuda.nvtx.range_pop() return output, output_bias @@ -463,8 +476,8 @@ def __init__(self, config, layer_number, self.query_key_value = Linear( in_features=config.hidden_size, out_features=query_projection_size + 2 * kv_projection_size, - skip_bias_add=True, - init_method=config.init_method) + init_method=config.init_method, + bias=args.add_bias_linear) else: raise NotImplementedError assert attention_type == AttnType.cross_attn @@ -504,7 +517,9 @@ def __init__(self, config, layer_number, out_features=config.hidden_size, skip_bias_add=True, init_method=config.output_layer_init_method, - transpose=True) + transpose=True, + bias=args.add_bias_linear + ) def _checkpointed_attention_forward(self, query_layer, key_layer, @@ -573,7 +588,7 @@ def forward(self, hidden_states, attention_mask, if self.attention_type == AttnType.self_attn: # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states, scatter_input=False, gather_output=False, + mixed_x_layer = self.query_key_value(hidden_states, scatter_input=False, gather_output=False, cache_weights_in_all_gather=self.cache_weights_in_all_gather) # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] @@ -720,7 +735,12 @@ def forward(self, hidden_states, attention_mask, # Output. 
[sq, b, h] # ================= - output, bias = self.dense(context_layer, scatter_input=False, gather_output=False, cache_weights_in_all_gather=self.cache_weights_in_all_gather) + output = self.dense(context_layer, scatter_input=False, gather_output=False, cache_weights_in_all_gather=self.cache_weights_in_all_gather) + if isinstance(output, tuple): + output, bias = output + else: + output, bias = output, None + torch.cuda.nvtx.range_pop() return output, bias From 7da8f85335bd39d9baaad10fb20f44c9f73deabb Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 12 Feb 2024 08:00:31 -0500 Subject: [PATCH 21/25] add tinyllama init method and make pytorch optim default --- megatron/arguments.py | 10 ++++++++++ megatron/optimizer/__init__.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index eb6aae38f9..ef92ffd273 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -422,6 +422,12 @@ def core_transformer_config_from_args(args): if args.init_method_xavier_uniform: kw_args['init_method'] = torch.nn.init.xavier_uniform_ kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_ + elif args.init_method_tiny_llama: + from megatron.core.utils import init_method_normal + import math + kw_args['init_method'] = init_method_normal(math.sqrt(2.0 / 5 / args.hidden_size)) + kw_args['output_layer_init_method'] = init_method_normal(1 / math.sqrt(args.hidden_size) / args.num_layers ) + if args.group_query_attention: kw_args['num_query_groups'] = args.num_query_groups else: @@ -699,6 +705,8 @@ def _add_regularization_args(parser): 'numerical stability') group.add_argument('--sgd-momentum', type=float, default=0.9, help='Momentum factor for sgd') + group.add_argument('--use-apex-adam', action='store_true', default=False, + help="Use Apex's implementation of Adam") return parser @@ -861,6 +869,8 @@ def _add_initialization_args(parser): 'distribution used for weight initialization.') group.add_argument('--init-method-xavier-uniform', action='store_true', help='Enable Xavier uniform parameter initialization') + group.add_argument('--init-method-tiny-llama', action='store_true', + help='Enable Tiny LLaMA based initialization') return parser diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 33744a2f3a..29fb8e0760 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
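To make the new --init-method-tiny-llama option concrete, the short sketch below simply evaluates the two standard deviations configured in the arguments.py change above, using the TinyLlama-1.1B shapes from this series (hidden size 2048, 22 layers); nothing here is new logic, only the patch's formulas computed numerically.

```
# Evaluate the initialization std devs set by --init-method-tiny-llama
# (formulas copied from the arguments.py change above; the shapes assume
# TinyLlama-1.1B: hidden_size=2048, num_layers=22).
import math

hidden_size = 2048
num_layers = 22

general_std = math.sqrt(2.0 / 5 / hidden_size)        # ~0.01398 for most weights
output_std = 1 / math.sqrt(hidden_size) / num_layers  # ~0.001004 for output projections

print(f"general init std: {general_std:.5f}")
print(f"output-layer init std: {output_std:.6f}")
```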
-from apex.optimizers import FusedAdam as Adam +from apex.optimizers import FusedAdam as ApexAdam +from torch.optim import AdamW as Adam from apex.optimizers import FusedSGD as SGD from megatron import get_args @@ -72,7 +73,14 @@ def get_megatron_optimizer(model, lr_mult) if args.optimizer == 'adam': - optimizer = Adam(param_groups, + if args.use_apex_adam: + optimizer = ApexAdam(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps) + else: + optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), From 715aef8ec4c7c922f968fa79b948da17d2a7e1d2 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 12 Feb 2024 19:35:29 -0500 Subject: [PATCH 22/25] add comment about seed --- custom_litgpt_dataloader/data_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/custom_litgpt_dataloader/data_util.py b/custom_litgpt_dataloader/data_util.py index 0412110b9e..36da5ae70c 100644 --- a/custom_litgpt_dataloader/data_util.py +++ b/custom_litgpt_dataloader/data_util.py @@ -48,6 +48,8 @@ def create_dataloader( sum_weights = sum(weights) weights = [el / sum_weights for el in weights] + #having different seeds here is important such that each batch has tokens + #from all data mixtures. combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) From e610c94ccbaaa9c8f9c4a4b008e2352f88939d1b Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 12 Feb 2024 19:36:06 -0500 Subject: [PATCH 23/25] make dataloader exactly like neel's litgpt --- pretrain_gpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index b7ea61d8e0..e0cb1f8901 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -148,6 +148,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): block_size= args.seq_length, train_data_dir = args.lit_gpt_data_path, val_data_dir = args.lit_gpt_data_path, + seed = args.seed + ax.config.data_parallel_rank ) # these flags are set within megatron in # the OG dataloader @@ -168,7 +169,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): next(valid_iterator) end = time.time() print_rank_0(f"Time for rewinding the dataloader on rank 0 = {end-start:.2f} s") - + return train_iterator, valid_iterator else: print_rank_0('> building train, validation, and test datasets ' From 0881c29a174341b1daa796a24ed6880b2f3ae202 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 16 Feb 2024 12:18:53 -0500 Subject: [PATCH 24/25] option to disable activation checkpointing --- examples/run_axonn_amd_tinyllama.sh | 21 +++++++++++++++------ megatron/model/transformer.py | 1 - 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 4c54573c7e..473555b21a 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -100,7 +100,7 @@ TOKENIZER_MODEL="${TOKENIZER_DIR}/tokenizer.model" # we will save and load model checkpoints here # if these are non-empty training will restart from the latest checkpoint here # else training will start from scratch -CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints" +CHECKPOINT_PATH="/lustre/orion/csc569/proj-shared/megatron-axonn-tiny-llama-1.1b/checkpoints/dataloader_correction" # tiny-llama1.1B 
architecture shapes # https://github.com/azshue/lit-gpt-dev/blob/tiny-llama/lit_gpt/config.py @@ -126,10 +126,14 @@ PIPE_PARR=1 CACHE_LAYERS=0 OVERLAP=True + +GRAD_ACC=2 +GRADIENT_CHECKPOINT=False + ## DERIVED ARGUMENTS (ignore) MP=$(( ROW_TENSOR_PARR * COLUMN_TENSOR_PARR * DEPTH_TENSOR_PARR )) DP=$(( GPUS / MP )) -MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP )) +MICRO_BATCH_SIZE=$(( GLOBAL_BATCH_SIZE / DP / GRAD_ACC )) # The following args enable LLaMA # --swiglu makes ParallelMLP equivalent to LLAMAMLP @@ -170,9 +174,6 @@ GPT_ARGS=" --bf16 \ --no-gradient-accumulation-fusion \ --use-amd \ - --recompute-granularity full \ - --recompute-method uniform \ - --recompute-num-layers 1 \ --use-flash-attn \ --swiglu \ --use-rotary-position-embeddings \ @@ -180,9 +181,17 @@ GPT_ARGS=" --group-query-attention \ --num-query-groups ${NUM_QUERY_GROUPS} \ --untie-embeddings-and-output-weights \ - --disable-bias-linear + --disable-bias-linear \ + --use-apex-adam " +if [[ $GRADIENT_CHECKPOINT == "True" ]] +then + GPT_ARGS="${GPT_ARGS} --recompute-granularity full \ + --recompute-method uniform \ + --recompute-num-layers 1" +fi + ## AxoNN specific args for communication optimizations # these do not affect the ML science if [[ $OVERLAP == "True" ]] diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 2fc33ba12c..72b0e1ec85 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1695,7 +1695,6 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb, is_first_microbatch) else: - raise NotImplementedError forward_kwargs = { 'encoder_output': encoder_output, 'enc_dec_attn_mask': enc_dec_attn_mask, From 866e58ea4464ca49b0473ec13467413f50fa8a8d Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sat, 17 Feb 2024 11:13:03 -0500 Subject: [PATCH 25/25] latest changes --- examples/run_axonn_amd_tinyllama.sh | 5 ++++- pretrain_gpt.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/run_axonn_amd_tinyllama.sh b/examples/run_axonn_amd_tinyllama.sh index 473555b21a..e7a491b23d 100755 --- a/examples/run_axonn_amd_tinyllama.sh +++ b/examples/run_axonn_amd_tinyllama.sh @@ -182,7 +182,10 @@ GPT_ARGS=" --num-query-groups ${NUM_QUERY_GROUPS} \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ - --use-apex-adam + --use-apex-adam \ + --seed 78965 \ + --attention-dropout 0 \ + --hidden-dropout 0 " if [[ $GRADIENT_CHECKPOINT == "True" ]] diff --git a/pretrain_gpt.py b/pretrain_gpt.py index e0cb1f8901..75757ae5e5 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -148,7 +148,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): block_size= args.seq_length, train_data_dir = args.lit_gpt_data_path, val_data_dir = args.lit_gpt_data_path, - seed = args.seed + ax.config.data_parallel_rank + seed = 12345 ) # these flags are set within megatron in # the OG dataloader
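With the gradient-accumulation change above, the per-GPU micro batch becomes GLOBAL_BATCH_SIZE / DP / GRAD_ACC, so each rank now runs GRAD_ACC smaller forward/backward passes per optimizer step. Below is a minimal sketch of that bookkeeping, assuming an example 256-GPU job (the node count is chosen at submission time) and the script's current settings (all tensor-parallel sizes 1, GLOBAL_BATCH_SIZE=2048, SEQUENCE_LENGTH=2048, GRAD_ACC=2).

```
# Sketch of the batch-size bookkeeping after the gradient-accumulation change.
# gpus=256 (32 nodes x 8 GPUs) is an assumed example; the parallelism and
# batch settings match the script above.
gpus = 256
row_tp = col_tp = depth_tp = 1
global_batch_size = 2048
sequence_length = 2048
grad_acc = 2

mp = row_tp * col_tp * depth_tp           # model-parallel group size
dp = gpus // mp                           # data-parallel group size
micro_batch_size = global_batch_size // dp // grad_acc

samples_per_rank_per_step = micro_batch_size * grad_acc
tokens_per_step = global_batch_size * sequence_length

print(f"DP={dp}, micro batch per rank={micro_batch_size}")
print(f"each rank accumulates {grad_acc} micro batches "
      f"({samples_per_rank_per_step} samples) per optimizer step")
print(f"tokens per optimizer step: {tokens_per_step:,}")   # 4,194,304 (~4M)
```

The 2048 x 2048 = 4,194,304 tokens per optimizer step matches the roughly 4M-token batches the configuration comments aim for.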