diff --git a/ompi/mca/coll/accelerator/coll_accelerator.h b/ompi/mca/coll/accelerator/coll_accelerator.h index b170e38f268..e707d7ec7f2 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator.h +++ b/ompi/mca/coll/accelerator/coll_accelerator.h @@ -1,4 +1,5 @@ /* + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2014 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. @@ -45,6 +46,11 @@ mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, size_t count, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); +int mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + mca_coll_base_module_t *module); + int mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, diff --git a/ompi/mca/coll/accelerator/coll_accelerator_module.c b/ompi/mca/coll/accelerator/coll_accelerator_module.c index 4fe1603a8aa..4005f6cdec9 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_module.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_module.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2014-2017 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. @@ -94,6 +95,7 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce; accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce; + accelerator_module->super.coll_reduce_local = mca_coll_accelerator_reduce_local; accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block; if (!OMPI_COMM_IS_INTER(comm)) { accelerator_module->super.coll_scan = mca_coll_accelerator_scan; @@ -141,6 +143,7 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local); ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ @@ -159,6 +162,7 @@ mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce); ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_local); ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block); if (!OMPI_COMM_IS_INTER(comm)) { diff --git a/ompi/mca/coll/accelerator/coll_accelerator_reduce.c b/ompi/mca/coll/accelerator/coll_accelerator_reduce.c index 6b0d3d5d72b..38143317c42 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_reduce.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_reduce.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. @@ -35,7 +36,7 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count, mca_coll_base_module_t *module) { mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; - int rank = ompi_comm_rank(comm); + int rank = (comm == NULL) ? -1 : ompi_comm_rank(comm); ptrdiff_t gap; char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL; size_t bufsize; @@ -70,9 +71,15 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count, rbuf2 = rbuf; /* save away original buffer */ rbuf = rbuf1 - gap; } - rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count, - dtype, op, root, comm, - s->c_coll.coll_reduce_module); + + if ((comm == NULL) && (root == -1)) { + ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype); + rc = OMPI_SUCCESS; + } else { + rc = s->c_coll.coll_reduce((void *) sbuf, rbuf, count, + dtype, op, root, comm, + s->c_coll.coll_reduce_module); + } if (NULL != sbuf1) { free(sbuf1); @@ -84,3 +91,13 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count, } return rc; } + +int +mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + mca_coll_base_module_t *module) +{ + return mca_coll_accelerator_reduce(sbuf, rbuf, count, dtype, op, -1, NULL, + module); +}