Skip to content

Commit

Permalink
adjust executor python test
Browse files Browse the repository at this point in the history
  • Loading branch information
caiomcbr committed Sep 19, 2024
1 parent 226b44b commit ce3bbd9
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,23 @@ def main(
npkit.init(mscclpp_group.my_rank)
execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)

cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = cp.zeros(nelems, dtype=dtype)
for i in range(nelems):
sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i]

if "allgather" in execution_plan_name:
cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
buffer = cp.empty(nelems * mscclpp_group.nranks, dtype=dtype)
buffer[:] = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = cp.zeros(nelems, dtype=dtype)
for i in range(nelems):
sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i]
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
expected = buffer
else:
cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
sendbuf = cp.random.random(nelems).astype(dtype)
expected = cp.asnumpy(sendbuf)
expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM)
recvbuf = cp.zeros(nelems, dtype=dtype)
expected = cp.zeros_like(sendbuf, dtype=dtype)
for i in range(mscclpp_group.nranks):
expected += sub_arrays[i]
mscclpp_group.barrier()

executor_func = lambda stream: executor.execute(
Expand Down

0 comments on commit ce3bbd9

Please sign in to comment.