Remove requirement to initialize MPI for tensor+data parallelism #20

Open · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions examples/get_rank_from_slurm.sh
@@ -0,0 +1,4 @@
#!/bin/bash
# select_gpu_device wrapper script
export RANK=${SLURM_PROCID}
exec $*
81 changes: 44 additions & 37 deletions examples/run_axonn.sh
@@ -1,69 +1,74 @@
#!/bin/bash

# Runs the "345M" parameter model
echo "This trains a 5B parameter model on 64 GPUs of Perlmutter"

export CUDA_DEVICE_MAX_CONNECTIONS=1


NNODES=$SLURM_JOB_NUM_NODES
GPUS=$(( NNODES * 4 ))

export WORLD_SIZE=$GPUS
export MASTER_ADDR=$(hostname)
export MASTER_PORT=29500
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_NET_GDR_LEVEL=PHB
export CUDA_DEVICE_MAX_CONNECTIONS=1
export CUDA_VISIBLE_DEVICES=3,2,1,0

export NCCL_NET_GDR_LEVEL=PHB
export NCCL_CROSS_NIC=1
export NCCL_SOCKET_IFNAME=hsn

# these are specific to perlmutter's slingshot-11 network
#
export NCCL_NET="AWS Libfabric"
export FI_CXI_RDZV_THRESHOLD=0
export FI_CXI_RDZV_GET_MIN=0
export FI_CXI_OFLOW_BUF_SIZE=1073741824
export FI_CXI_OFLOW_BUF_COUNT=1


DATA_DIR="${SCRATCH}/gpt_data"
CHECKPOINT_PATH="${DATA_DIR}/checkpoints"
VOCAB_FILE="${DATA_DIR}/gpt2-vocab.json"
MERGE_FILE="${DATA_DIR}/gpt2-merges.txt"
DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document"
DATA_DIR="$SCRATCH/gpt_data"
CHECKPOINT_PATH="$DATA_DIR/checkpoints"
VOCAB_FILE="$DATA_DIR/gpt2-vocab.json"
MERGE_FILE="$DATA_DIR/gpt2-merges.txt"
DATA_PATH="$DATA_DIR/BookCorpusDataset_text_document"

## ARCHITECTURE DETAILS
NUM_LAYERS=30
NUM_HEADS=40
HIDDEN_SIZE=5120
NUM_LAYERS=24
NUM_HEADS=32
HIDDEN_SIZE=4096

## PARALLELISM DETAILS
COLUMN_TENSOR_PARR=1
ROW_TENSOR_PARR=1
DEPTH_TENSOR_PARR=8
DEPTH_TENSOR_PARR=16
PIPE_PARR=1
CACHE_LEVEL=0
CACHE_LAYERS=0
OVERLAP=True

NSYS_PROFILE=False
PROFILE_NAME="test_10B_16x1"

## BATCH SIZES
MICRO_BATCH_SIZE=8
GLOBAL_BATCH_SIZE=16
MICRO_BATCH_SIZE=128
GLOBAL_BATCH_SIZE=512
SEQUENCE_LENGTH=2048
TRAIN_ITERS=10
TRAIN_ITERS=20

GPT_ARGS="
--row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \
--column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \
--depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \
--pipeline-model-parallel-size ${PIPE_PARR} \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
--num-attention-heads ${NUM_HEADS} \
--seq-length ${SEQUENCE_LENGTH} \
--max-position-embeddings ${SEQUENCE_LENGTH} \
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--row-tensor-model-parallel-size $ROW_TENSOR_PARR \
--column-tensor-model-parallel-size $COLUMN_TENSOR_PARR \
--depth-tensor-model-parallel-size $DEPTH_TENSOR_PARR \
--pipeline-model-parallel-size $PIPE_PARR \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
--seq-length $SEQUENCE_LENGTH \
--max-position-embeddings $SEQUENCE_LENGTH \
--micro-batch-size $MICRO_BATCH_SIZE \
--global-batch-size $GLOBAL_BATCH_SIZE \
--lr 0.00015 \
--train-iters ${TRAIN_ITERS} \
--train-iters $TRAIN_ITERS \
--lr-decay-iters 320000 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
@@ -75,14 +80,16 @@ GPT_ARGS="
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers 1 \
--layer-caching-level $CACHE_LEVEL

"
if [[ $OVERLAP == "True" ]]
then
GPT_ARGS="${GPT_ARGS} \
GPT_ARGS="$GPT_ARGS \
--overlap-axonn-comm \
--overlap-axonn-reduce-scatter \
--overlap-axonn-all-gather\
--num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather ${CACHE_LAYERS}"
--num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather $CACHE_LAYERS"
fi


@@ -97,7 +104,7 @@ OUTPUT_ARGS="
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 1
--eval-iters 0
"


@@ -109,15 +116,15 @@ SCRIPT="python -u pretrain_gpt.py \
--distributed-backend nccl \
"

if [[ ${NSYS_PROFILE} == "True" ]]
if [[ $NSYS_PROFILE == "True" ]]
then
echo "profiling with nsys"
SCRIPT="nsys profile -s none \
-t nvtx,cuda -o ${PROFILE_NAME} \
-t nvtx,cuda -o $PROFILE_NAME \
--force-overwrite=true \
--capture-range=cudaProfilerApi \
--capture-range-end=stop \
${SCRIPT} \
$SCRIPT \
--profile-step-start 5 \
--profile-step-end 10 \
--profile
@@ -128,8 +135,8 @@ fi
#--save $CHECKPOINT_PATH \
# --load $CHECKPOINT_PATH

run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ${SCRIPT}"
run_cmd="srun -C gpu -N $NNODES -n $GPUS -c 32 --cpu-bind=cores --gpus-per-node=4 ./examples/get_rank_from_slurm.sh $SCRIPT"

echo ${run_cmd}
eval ${run_cmd}
echo $run_cmd
eval $run_cmd
set +x
1 change: 1 addition & 0 deletions megatron/arguments.py
@@ -1004,6 +1004,7 @@ def _add_distributed_args(parser):
help='Overlap reduce scatters in backward pass of AxoNN\'s tensor parallelism')
group.add_argument('--num-layers-for-caching-weights-in-depth-tensor-parallel-all-gather', type=int, default=0,
help='number of layers to cache weights during the first all-gather for a batch')
group.add_argument('--layer-caching-level', type=int, default=2)
group.add_argument('--overlap-axonn-all-gather', action='store_true', default=False,
help='Overlap all-gathers in forward pass of AxoNN\'s tensor parallelism')

12 changes: 8 additions & 4 deletions megatron/model/transformer.py
@@ -129,9 +129,10 @@ def squared_relu(x):

def forward(self, hidden_states):
torch.cuda.nvtx.range_push(f"MLP Block")
args = get_args()
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states, scatter_input=False, gather_output=False,
cache_weights_in_all_gather = self.cache_weights_in_all_gather)
cache_weights_in_all_gather = self.cache_weights_in_all_gather and args.layer_caching_level>=1)

if self.bias_gelu_fusion:
assert self.add_bias is True
Expand All @@ -144,7 +145,7 @@ def forward(self, hidden_states):

# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel, scatter_input=False, gather_output=False,
cache_weights_in_all_gather = self.cache_weights_in_all_gather)
cache_weights_in_all_gather = self.cache_weights_in_all_gather and args.layer_caching_level==2)
torch.cuda.nvtx.range_pop()
return output, output_bias

@@ -547,6 +548,7 @@ def forward(self, hidden_states, attention_mask,
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
args = get_args()
torch.cuda.nvtx.range_push(f"Attention Block")
is_first_step = False
if inference_params:
@@ -574,7 +576,7 @@

# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states, scatter_input=False, gather_output=False,
cache_weights_in_all_gather=self.cache_weights_in_all_gather)
cache_weights_in_all_gather=self.cache_weights_in_all_gather and args.layer_caching_level==2)

# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
@@ -720,7 +722,9 @@ def forward(self, hidden_states, attention_mask,
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer, scatter_input=False, gather_output=False, cache_weights_in_all_gather=self.cache_weights_in_all_gather)
output, bias = self.dense(context_layer, scatter_input=False, gather_output=False,
cache_weights_in_all_gather=self.cache_weights_in_all_gather
and args.layer_caching_level==2)
torch.cuda.nvtx.range_pop()
return output, bias

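Note (not part of the diff): a minimal Python sketch of how the new --layer-caching-level flag appears to gate weight caching across a layer's four GEMMs, based on the changes above; the helper name and the level shorthand are mine.

def cache_weights_for(gemm, layer_flag, layer_caching_level):
    # layer_flag mirrors self.cache_weights_in_all_gather for this layer;
    # the new CLI flag then narrows it per GEMM:
    #   level 0 - cache nothing
    #   level 1 - cache only dense_h_to_4h
    #   level 2 - cache all four GEMMs (the default, i.e. the previous behaviour)
    if not layer_flag:
        return False
    if gemm == "dense_h_to_4h":
        return layer_caching_level >= 1
    # query_key_value, attention-output dense, dense_4h_to_h
    return layer_caching_level == 2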
11 changes: 9 additions & 2 deletions megatron/training.py
@@ -99,8 +99,8 @@ def pretrain(train_valid_test_dataset_provider,
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
#torch.distributed.all_reduce(start_time_tensor,
# op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
@@ -652,6 +652,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
total_loss_dict[nan_iters_key])
log_string += ' theoretical FLOP/s: {:.3f} TFLOP/s | '.format(get_flops(elapsed_time_per_iteration))
log_string += ' model size: {:.3f} B params | '.format(get_params())
log_string += f' using mpi: {is_mpi_init()}| '
curr, peak = get_mem()
log_string += ' memory used by tensors {:.3f} GB ( peak {:.3f} GB)'.format(curr, peak)

@@ -667,6 +668,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,

return report_memory_flag

def is_mpi_init():
try:
from mpi4py import MPI
except ImportError:
return False
return MPI.Is_initialized()

def get_flops(batch_time):
args = get_args()
7 changes: 6 additions & 1 deletion pretrain_gpt.py
@@ -159,7 +159,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
return train_ds, valid_ds, test_ds


def set_device_and_init_torch_dist():
def set_device_and_init_torch_dist_mpi():
from mpi4py import MPI
import os

@@ -186,6 +186,11 @@ def set_device_and_init_torch_dist():
os.environ["RANK"] = str(world_rank)
os.environ["WORLD_SIZE"] = str(world_size)

def set_device_and_init_torch_dist():
torch.distributed.init_process_group(
backend="nccl",
)


if __name__ == "__main__":
set_device_and_init_torch_dist()
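For context (not part of the diff): with no init_method, torch.distributed.init_process_group falls back to the env:// rendezvous and reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment, which run_axonn.sh and the new get_rank_from_slurm.sh wrapper export, so no MPI is needed for the bootstrap. A standalone sketch of that path; the SLURM_LOCALID-based device binding is an illustrative assumption, since the run script itself relies on CUDA_VISIBLE_DEVICES.

import os
import torch
import torch.distributed as dist

# MASTER_ADDR / MASTER_PORT / WORLD_SIZE come from run_axonn.sh;
# RANK is set per process by get_rank_from_slurm.sh (from SLURM_PROCID).
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Illustrative local-device binding (not in the PR): one GPU per rank per node.
local_rank = int(os.environ.get("SLURM_LOCALID", 0)) % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# Same call as in the new set_device_and_init_torch_dist(); passing rank and
# world_size explicitly is optional because env:// reads them from the environment.
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)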