Skip to content

Commit

Permalink
Fix allgather kernel 2 perf bug (#108)
Browse files Browse the repository at this point in the history
Fix #105
  • Loading branch information
Binyang2014 committed Jun 16, 2023
1 parent 6cd8960 commit 8410fcd
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions test/mscclpp-test/allgather_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ __device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int ra
}

__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int worldSize,
int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) {
int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size,
bool flushAfterSignal = true) {
// this allgather algorithm works as follows:
// Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode
// and waits for data from GPU rank (i-1) % nranksPerNode
Expand All @@ -42,7 +43,8 @@ __device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, in
for (int i = 1; i < nranksPerNode; i++) {
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
if ((threadIdx.x % 32) == 0) devChan.putWithSignalAndFlush(offset, size);
if (flushAfterSignal && (threadIdx.x % 32) == 0) devChan.putWithSignalAndFlush(offset, size);
if (!flushAfterSignal && (threadIdx.x % 32) == 0) devChan.putWithSignal(offset, size);
}
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
Expand Down Expand Up @@ -76,37 +78,48 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra
// local allgather
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
localAllGather(devChan, rank, worldSize, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
nelemsPerGPU * sizeof(int));
nelemsPerGPU * sizeof(int), false);
}
// cross-node exchange
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
devChan.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0) devChan.wait();
}

// sync here to make sure IB flush dose not block the CUDA IPC traffic
__syncthreads();
// since all CUDA IPC share the same CUDA stream, only need to flush one of devChans
if ((remoteRank % nranksPerNode == rank % nranksPerNode) ||
(remoteRank / nranksPerNode == rank / nranksPerNode && rank % nranksPerNode == 0)) {
if ((threadIdx.x % 32) == 0) devChan.flush();
}
__syncthreads();

// Step 2
// local allgather
int otherNghr = (rank + nranksPerNode) % worldSize;
if (remoteRank / nranksPerNode == rank / nranksPerNode) {
localAllGather(devChan, rank, worldSize, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int), false);
}

// cross-node exchange
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
devChan.putWithSignalAndFlush(
(rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0) devChan.wait();
}

__syncthreads();
if ((remoteRank % nranksPerNode == rank % nranksPerNode) ||
(remoteRank / nranksPerNode == rank / nranksPerNode && rank % nranksPerNode == 0)) {
if ((threadIdx.x % 32) == 0) devChan.flush();
}
__syncthreads();

// Step 3
Expand Down

0 comments on commit 8410fcd

Please sign in to comment.