Add pipelining #19

Open
wants to merge 12 commits into base: main
85 changes: 68 additions & 17 deletions examples/run_axonn.sh
@@ -9,9 +9,9 @@ NNODES=$SLURM_JOB_NUM_NODES
GPUS=$(( NNODES * 4 ))
export MASTER_ADDR=$(hostname)
export MASTER_PORT=29500
export CUDA_DEVICE_MAX_CONNECTIONS=1
#export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_NET_GDR_LEVEL=PHB
export CUDA_DEVICE_MAX_CONNECTIONS=1
#export CUDA_DEVICE_MAX_CONNECTIONS=1
export CUDA_VISIBLE_DEVICES=3,2,1,0
export NCCL_CROSS_NIC=1
export NCCL_SOCKET_IFNAME=hsn
@@ -29,25 +29,35 @@ MERGE_FILE="${DATA_DIR}/gpt2-merges.txt"
DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document"

## ARCHITECTURE DETAILS
NUM_LAYERS=24
HIDDEN_SIZE=1024
NUM_HEADS=16
NUM_LAYERS=32
NUM_HEADS=56
HIDDEN_SIZE=7168 #$(( 128 * NUM_HEADS ))

## PARALLELISM DETAILS
COLUMN_TENSOR_PARR=1
ROW_TENSOR_PARR=1
DEPTH_TENSOR_PARR=4
DEPTH_TENSOR_PARR=16
PIPE_PARR=1
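# When OVERLAP is True, the AxoNN communication-overlap flags are appended to GPT_ARGS below.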
OVERLAP=True

NSYS_PROFILE=False
PROF_OUTPUT="test_red_scat_without_max_con"

## BATCH SIZES
MICRO_BATCH_SIZE=16
GLOBAL_BATCH_SIZE=16
SEQUENCE_LENGTH=1024
MICRO_BATCH_SIZE=64
GLOBAL_BATCH_SIZE=64
SEQUENCE_LENGTH=2048
TRAIN_ITERS=10

#OUTPUT_FOLDER="./logs/seq_len"
#OUTPUT_FILE="${OUTPUT_FOLDER}/TP-${COLUMN_TENSOR_PARR}x${ROW_TENSOR_PARR}x${DEPTH_TENSOR_PARR}_PP-${PIPE_PARR}_mbs-${MICRO_BATCH_SIZE}-bs-${GLOBAL_BATCH_SIZE}-overlap-${OVERLAP}-seq-length-${SEQUENCE_LENGTH}"
#mkdir -p ${OUTPUT_FOLDER}

GPT_ARGS="
--row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \
--column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \
--depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \
--pipeline-model-parallel-size ${PIPE_PARR} \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
--num-attention-heads ${NUM_HEADS} \
@@ -56,15 +66,36 @@ GPT_ARGS="
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--lr 0.00015 \
--train-iters 500000 \
--train-iters ${TRAIN_ITERS} \
--lr-decay-iters 320000 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--lr-warmup-fraction .01 \
--clip-grad 1.0 \
--fp16 \
"
--bf16 \
--use-flash-attn \
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers 1 \
--no-gradient-accumulation-fusion \
--untie-embeddings-and-output-weights \
--no-async-tensor-model-parallel-allreduce
" # only set for pipelineing


if [[ $OVERLAP == "True" ]]
then
GPT_ARGS="${GPT_ARGS} \
--overlap-axonn-comm \
--overlap-axonn-reduce-scatter \
--overlap-axonn-all-gather \
--cache-weights-in-depth-tensor-parallelism
"

fi



DATA_ARGS="
--data-path $DATA_PATH \
@@ -77,20 +108,40 @@ OUTPUT_ARGS="
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10
--eval-iters 1
"



SCRIPT="python -u pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH
"


run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ${SCRIPT}"
if [[ ${NSYS_PROFILE} == "True" ]]
then
echo "profiling with nsys"
SCRIPT="nsys profile -s none \
-t nvtx,cuda -o ${PROF_OUTPUT}.qdrep \
--force-overwrite=true \
--capture-range=cudaProfilerApi \
--capture-range-end=stop \
${SCRIPT} \
--profile-step-start 5 \
--profile-step-end 10 \
--profile
"
fi


#--save $CHECKPOINT_PATH \
# --load $CHECKPOINT_PATH

export MPICH_GPU_SUPPORT_ENABLED=1
export CRAY_ACCEL_TARGET=nvidia80
run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ${SCRIPT}" #| tee ${OUTPUT_FILE}"

echo ${run_cmd}
eval ${run_cmd}
106 changes: 106 additions & 0 deletions examples/run_axonn_350M.sh
@@ -0,0 +1,106 @@
#!/bin/bash

# Runs the "345M" parameter model

export CUDA_DEVICE_MAX_CONNECTIONS=1


NNODES=$SLURM_JOB_NUM_NODES
GPUS=$(( NNODES * 4 ))
export MASTER_ADDR=$(hostname)
export MASTER_PORT=29500
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_NET_GDR_LEVEL=PHB
export CUDA_DEVICE_MAX_CONNECTIONS=1
export CUDA_VISIBLE_DEVICES=3,2,1,0
export NCCL_CROSS_NIC=1
export NCCL_SOCKET_IFNAME=hsn
export NCCL_NET="AWS Libfabric"
export FI_CXI_RDZV_THRESHOLD=0
export FI_CXI_RDZV_GET_MIN=0
export FI_CXI_OFLOW_BUF_SIZE=1073741824
export FI_CXI_OFLOW_BUF_COUNT=1


DATA_DIR="${SCRATCH}/gpt_data"
CHECKPOINT_PATH="${DATA_DIR}/checkpoints"
VOCAB_FILE="${DATA_DIR}/gpt2-vocab.json"
MERGE_FILE="${DATA_DIR}/gpt2-merges.txt"
DATA_PATH="${DATA_DIR}/BookCorpusDataset_text_document"

## ARCHITECTURE DETAILS
NUM_LAYERS=24
HIDDEN_SIZE=1024
NUM_HEADS=16

## PARALLELISM DETAILS
COLUMN_TENSOR_PARR=1
ROW_TENSOR_PARR=1
DEPTH_TENSOR_PARR=4

## BATCH SIZES
MICRO_BATCH_SIZE=8
GLOBAL_BATCH_SIZE=16
SEQUENCE_LENGTH=1024

OVERLAP="True"

GPT_ARGS="
--row-tensor-model-parallel-size ${ROW_TENSOR_PARR} \
--column-tensor-model-parallel-size ${COLUMN_TENSOR_PARR} \
--depth-tensor-model-parallel-size ${DEPTH_TENSOR_PARR} \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
--num-attention-heads ${NUM_HEADS} \
--seq-length ${SEQUENCE_LENGTH} \
--max-position-embeddings ${SEQUENCE_LENGTH} \
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--lr 0.00015 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--lr-warmup-fraction .01 \
--clip-grad 1.0 \
--fp16 \
"


if [[ $OVERLAP == "True" ]]
then
GPT_ARGS="${GPT_ARGS} \
--overlap-axonn-comm \
--cache-weights-in-depth-tensor-parallelism"
fi

DATA_ARGS="
--data-path $DATA_PATH \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--split 949,50,1
"

OUTPUT_ARGS="
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10
"

SCRIPT="python -u pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH
"


run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ${SCRIPT}"

echo ${run_cmd}
eval ${run_cmd}
set +x
9 changes: 9 additions & 0 deletions megatron/arguments.py
@@ -997,6 +997,15 @@ def _add_distributed_args(parser):
help='Degree of tensor model parallelism.')
group.add_argument('--depth-tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--overlap-axonn-comm', action='store_true', default=False,
help='Overlap all-reduces in backward pass of AxoNN\'s tensor parallelism')
group.add_argument('--overlap-axonn-reduce-scatter', action='store_true', default=False,
help='Overlap reduce scatters in backward pass of AxoNN\'s tensor parallelism')
group.add_argument('--cache-weights-in-depth-tensor-parallelism', action='store_true', default=False,
help='Cache weights during the first all-gather for a batch and reuse them until optimizer.step()')
group.add_argument('--overlap-axonn-all-gather', action='store_true', default=False,
help='Overlap all-gathers in forward pass of AxoNN\'s tensor parallelism')

group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
78 changes: 56 additions & 22 deletions megatron/core/pipeline_parallel/schedules.py
@@ -13,6 +13,11 @@
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type

from axonn.intra_layer import optimize_communication
from megatron import get_args
from contextlib import nullcontext
from functools import partial

# Types
Shape = Union[List[int], torch.Size]

@@ -324,36 +329,65 @@ def forward_backward_no_pipelining(

forward_data_store = []
input_tensor, output_tensor_grad = None, None

args = get_args()
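# Wrap each microbatch's forward/backward in AxoNN's optimize_communication
# context when communication overlap is requested; otherwise use a no-op context.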
if args.overlap_axonn_comm:
ctx = partial(optimize_communication,
overlap_all_reduce=True,
overlap_reduce_scatter=args.overlap_axonn_reduce_scatter,
cache_weights=args.cache_weights_in_depth_tensor_parallelism,
overlap_all_gather=args.overlap_axonn_all_gather,
model=model)
else:
ctx = nullcontext

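# When reduce-scatter overlap is enabled, Megatron-LM's gradient hook may not have
# folded every parameter's .grad into main_grad yet; post_process accumulates the
# remaining gradients manually and frees .grad.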
def post_process():
if args.overlap_axonn_reduce_scatter:
for param in model.parameters():
if param.requires_grad:
if param.grad is not None and not param.grad_added_to_main_grad:
param.main_grad.add_(param.grad.data)
param.grad = None

with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
)
with ctx():
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
)
if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
post_process() # need to call this because of the grad hook in megatron-lm

# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
)
with ctx():
output_tensor = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
)

if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
post_process() # need to call this because of the grad hook in megatron-lm

return forward_data_store

18 changes: 10 additions & 8 deletions megatron/core/transformer/transformer_config.py
@@ -190,10 +190,11 @@ def __post_init__(self):
)

if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
pass # this will be caught later
#raise ValueError(
# f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
# f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
#)

if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
@@ -205,10 +206,11 @@
self.num_query_groups = self.num_attention_heads

if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
pass # this will be caught later
#raise ValueError(
# f"num_query_groups ({self.num_query_groups}) must be a multiple of "
# f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
#)

if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True