diff --git a/ompi/mca/coll/han/Makefile.am b/ompi/mca/coll/han/Makefile.am
index 95ab470dc66..6d0e1a7d6f0 100644
--- a/ompi/mca/coll/han/Makefile.am
+++ b/ompi/mca/coll/han/Makefile.am
@@ -16,6 +16,7 @@ coll_han.h \
 coll_han_trigger.h \
 coll_han_algorithms.h \
 coll_han_alltoall.c \
+coll_han_alltoallv.c \
 coll_han_dynamic.h \
 coll_han_dynamic_file.h \
 coll_han_barrier.c \
diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h
index e6746d47dc4..f4a57d119b7 100644
--- a/ompi/mca/coll/han/coll_han.h
+++ b/ompi/mca/coll/han/coll_han.h
@@ -199,6 +199,7 @@ typedef struct mca_coll_han_op_module_name_t {
     mca_coll_han_op_up_low_module_name_t scatter;
     mca_coll_han_op_up_low_module_name_t scatterv;
     mca_coll_han_op_up_low_module_name_t alltoall;
+    mca_coll_han_op_up_low_module_name_t alltoallv;
 } mca_coll_han_op_module_name_t;
 
 /**
@@ -260,6 +261,11 @@ typedef struct mca_coll_han_component_t {
     /* alltoall: parallel stages */
     int32_t han_alltoall_pstages;
 
+    /* low level module for alltoallv */
+    uint32_t han_alltoallv_low_module;
+    int64_t han_alltoallv_smsc_avg_send_limit;
+    double han_alltoallv_smsc_noncontig_activation_limit;
+
     /* name of the modules */
     mca_coll_han_op_module_name_t han_op_module_name;
 
@@ -286,6 +292,8 @@ typedef struct mca_coll_han_component_t {
     /* Define maximum dynamic errors printed by rank 0 with a 0 verbosity level */
     int max_dynamic_errors;
+
+    opal_free_list_t pack_buffers;
 } mca_coll_han_component_t;
 
 /*
@@ -297,6 +305,7 @@ typedef struct mca_coll_han_single_collective_fallback_s
     union {
         mca_coll_base_module_alltoall_fn_t alltoall;
+        mca_coll_base_module_alltoallv_fn_t alltoallv;
         mca_coll_base_module_allgather_fn_t allgather;
         mca_coll_base_module_allgatherv_fn_t allgatherv;
         mca_coll_base_module_allreduce_fn_t allreduce;
@@ -319,6 +328,7 @@ typedef struct mca_coll_han_single_collective_fallback_s
 typedef struct mca_coll_han_collectives_fallback_s
 {
     mca_coll_han_single_collective_fallback_t alltoall;
+    mca_coll_han_single_collective_fallback_t alltoallv;
     mca_coll_han_single_collective_fallback_t allgather;
     mca_coll_han_single_collective_fallback_t allgatherv;
     mca_coll_han_single_collective_fallback_t allreduce;
@@ -384,6 +394,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
 #define previous_alltoall fallback.alltoall.alltoall
 #define previous_alltoall_module fallback.alltoall.module
 
+#define previous_alltoallv fallback.alltoallv.alltoallv
+#define previous_alltoallv_module fallback.alltoallv.module
+
 #define previous_allgather fallback.allgather.allgather
 #define previous_allgather_module fallback.allgather.module
 
@@ -440,6 +453,7 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
         HAN_UNINSTALL_COLL_API(COMM, HANM, allgather);                  \
         HAN_UNINSTALL_COLL_API(COMM, HANM, allgatherv);                 \
         HAN_UNINSTALL_COLL_API(COMM, HANM, alltoall);                   \
+        HAN_UNINSTALL_COLL_API(COMM, HANM, alltoallv);                  \
         han_module->enabled = false;  /* entire module set to pass-through from now on */ \
     } while(0)
 
@@ -503,6 +517,9 @@ int
 mca_coll_han_alltoall_intra_dynamic(ALLTOALL_BASE_ARGS,
                                     mca_coll_base_module_t *module);
 int
+mca_coll_han_alltoallv_intra_dynamic(ALLTOALLV_BASE_ARGS,
+                                     mca_coll_base_module_t *module);
+int
 mca_coll_han_allgather_intra_dynamic(ALLGATHER_BASE_ARGS,
                                      mca_coll_base_module_t *module);
 int
@@ -566,4 +583,7 @@ static inline struct mca_smsc_endpoint_t *mca_coll_han_get_smsc_endpoint (struct
     return (struct mca_smsc_endpoint_t *) proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_SMSC];
 }
 
+#define COLL_HAN_PACKBUF_PAYLOAD_BYTES (128*1024)
+
+
 #endif /* MCA_COLL_HAN_EXPORT_H */
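A note on the pack_buffers free list declared above: the pipelined exchange added in coll_han_alltoallv.c below checks 128 KiB payload buffers in and out of this list. A minimal sketch of the checkout/return idiom, using only the standard opal_free_list calls (illustrative, not part of the patch):

    opal_free_list_item_t *item = opal_free_list_wait(&mca_coll_han_component.pack_buffers);
    char *payload = (char*) item->ptr;   /* COLL_HAN_PACKBUF_PAYLOAD_BYTES of scratch space */
    /* ... pack into payload, post isend/irecv, wait for completion ... */
    opal_free_list_return(&mca_coll_han_component.pack_buffers, item);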
diff --git a/ompi/mca/coll/han/coll_han_algorithms.c b/ompi/mca/coll/han/coll_han_algorithms.c
index b113bd9ac07..1c821a3abfe 100644
--- a/ompi/mca/coll/han/coll_han_algorithms.c
+++ b/ompi/mca/coll/han/coll_han_algorithms.c
@@ -82,6 +82,10 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] =
         {"smsc", (fnptr_t)&mca_coll_han_alltoall_using_smsc}, // 2-level
         { 0 }
     },
+    [ALLTOALLV] = (mca_coll_han_algorithm_value_t[]){
+        {"smsc", (fnptr_t)&mca_coll_han_alltoallv_using_smsc}, // 2-level
+        { 0 }
+    },
 };
 
 int
diff --git a/ompi/mca/coll/han/coll_han_algorithms.h b/ompi/mca/coll/han/coll_han_algorithms.h
index 9889e5b644d..af486669f60 100644
--- a/ompi/mca/coll/han/coll_han_algorithms.h
+++ b/ompi/mca/coll/han/coll_han_algorithms.h
@@ -214,5 +214,10 @@ int
 mca_coll_han_alltoall_using_smsc(ALLTOALL_BASE_ARGS,
                                  mca_coll_base_module_t *module);
 
+/* Alltoallv */
+int
+mca_coll_han_alltoallv_using_smsc(ALLTOALLV_BASE_ARGS,
+                                  mca_coll_base_module_t *module);
+
 #endif
diff --git a/ompi/mca/coll/han/coll_han_alltoallv.c b/ompi/mca/coll/han/coll_han_alltoallv.c
new file mode 100644
index 00000000000..564cbb16b06
--- /dev/null
+++ b/ompi/mca/coll/han/coll_han_alltoallv.c
@@ -0,0 +1,856 @@
+/*
+ * Copyright (c) 2024      Amazon.com, Inc. or its affiliates.  All Rights Reserved.
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+/**
+ * @file
+ *
+ * This file contains the hierarchical implementations of alltoallv.
+ *
+ * mca_coll_han_alltoallv_using_smsc:
+ * This algorithm relies on SMSC, and specifically on XPMEM, because of
+ * the need to direct-map the memory.
+ *
+ * Each rank on one host is assigned a single partner on a remote host
+ * and vice versa.  The rank then collects all the data its partner will
+ * need to receive from its host and sends it in one large send; it
+ * likewise receives its data in one large recv, then cycles to the
+ * next host.
+ */
+
+#include <assert.h>
+
+#include "coll_han.h"
+#include "ompi/mca/coll/base/coll_base_functions.h"
+#include "ompi/mca/coll/base/coll_tags.h"
+#include "ompi/mca/pml/pml.h"
+#include "coll_han_trigger.h"
+#include "opal/mca/smsc/smsc.h"
+#include "opal/mca/rcache/rcache.h"
+#include "opal/datatype/opal_datatype.h"
+#include "opal/datatype/opal_datatype_internal.h"
+#include "ompi/mca/osc/base/base.h"
+
+/* Who is the given rank's partner during the exchange?
+   This schedule requires comm_size rounds, and your partner will
+   select you in the same round in which you select that partner. */
+static inline int ring_partner_no_skip(int rank, int round, int comm_size) {
+    /* make sure ring_partner is positive: make the argument to modulo > 0 by adding comm_size. */
+    return (comm_size + round - rank) % comm_size;
+}
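The pairing rule above can be sanity-checked in isolation: within a round the relation is symmetric, so the rank you pick also picks you, and across comm_size rounds each rank meets every rank exactly once. A standalone harness (illustrative only, not part of the patch):

    #include <assert.h>
    #include <stdio.h>

    static int ring_partner_no_skip(int rank, int round, int comm_size) {
        return (comm_size + round - rank) % comm_size;
    }

    int main(void) {
        const int comm_size = 6;   /* arbitrary small communicator */
        for (int round = 0; round < comm_size; round++) {
            for (int rank = 0; rank < comm_size; rank++) {
                int partner = ring_partner_no_skip(rank, round, comm_size);
                /* symmetry: my partner's partner is me, in the same round */
                assert(ring_partner_no_skip(partner, round, comm_size) == rank);
            }
        }
        printf("pairing is symmetric for comm_size=%d\n", comm_size);
        return 0;
    }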
+/* Who is the given rank's partner during the exchange?
+   This schedule requires comm_size rounds, and performs the
+   self-exchange last.
+*/
+static inline int ring_partner(int rank, int round, int comm_size) {
+    round = round % comm_size;
+    if (round == comm_size - 1) {
+        /* last round: self-exchange */
+        return rank;
+    }
+    int self_round = (2*rank) % comm_size;
+    if ( round < self_round )
+        return ring_partner_no_skip(rank, round, comm_size);
+    else
+        return ring_partner_no_skip(rank, round+1, comm_size);
+}
+
+struct peer_data {
+    const void *sbuf;            // mmapped: buf1
+    void *rbuf;                  // mmapped: buf2
+    struct peer_counts *counts;  // mmapped: buf3
+    opal_datatype_t *sendtype;   // deserialized from buf3, local copy
+    opal_datatype_t *recvtype;   // deserialized from buf3, local copy
+    void *map_ctx[3];
+};
+
+struct peer_counts {
+    size_t scount;
+    size_t sdispl;
+    size_t rcount;
+    size_t rdispl;
+};
+
+struct gathered_data {
+    size_t stype_serialized_length;
+    size_t rtype_serialized_length;
+    size_t sbuf_length;
+    size_t rbuf_length;
+    void *sbuf;
+    void *rbuf;
+    void *serialization_buffer;
+};
+
+/* Serialize the datatype into the buffer and return the number of bytes used.
+   If buf is NULL, just return the length of the required buffer. */
+static size_t ddt_pack_datatype(opal_datatype_t* type, uint8_t* buf)
+{
+    size_t length = sizeof(opal_datatype_t) - offsetof(opal_datatype_t, flags);
+    size_t n_copy = length;
+    bool count_only = buf == NULL;
+    if (!count_only) {
+        memcpy(buf, &type->flags, length);
+    }
+    buf += length;
+
+    if( type->opt_desc.used ) {
+        /* we are losing the non-optimized description of the datatype,
+         * but it is only useful for dumping the datatype description.
+         */
+        n_copy = (1 + type->opt_desc.used) * sizeof(dt_elem_desc_t);
+        if (!count_only) {
+            memcpy(buf, type->opt_desc.desc, n_copy);
+            buf += n_copy;
+        }
+        length += n_copy;
+    } else {
+        n_copy = (1 + type->desc.used) * sizeof(dt_elem_desc_t);
+        if (!count_only) {
+            memcpy(buf, type->desc.desc, n_copy);
+            buf += n_copy;
+        }
+        length += n_copy;
+    }
+    /* The following is not necessary and is, in fact, non-functional in homogeneous configurations. */
+    // if (type->ptypes) {
+    //     n_copy = OPAL_DATATYPE_MAX_SUPPORTED * sizeof(size_t);
+    // } else {
+    //     n_copy = 0;
+    // }
+    // if (!count_only) {
+    //     memcpy(buf, type->ptypes, n_copy);
+    // }
+    // length += n_copy;
+    return length;
+}
+
+static size_t ddt_unpack_datatype(opal_datatype_t* type, uint8_t* buf)
+{
+    OBJ_CONSTRUCT(type, opal_datatype_t);
+    size_t length = sizeof(opal_datatype_t) - offsetof(opal_datatype_t, flags);
+    memcpy(&type->flags, buf, length);
+    buf += length;
+    size_t nbytes_copy = (1+type->opt_desc.used) * sizeof(dt_elem_desc_t);
+    type->opt_desc.desc = (dt_elem_desc_t*)malloc(nbytes_copy);
+    memcpy(type->opt_desc.desc, buf, nbytes_copy);
+    length += nbytes_copy;
+    type->desc = type->opt_desc;
+    buf += nbytes_copy;
+    type->ptypes = NULL;
+    return length;
+}
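These two helpers form a measure-then-pack protocol: a NULL buffer only measures, a second call serializes, and the receiving side rebuilds a locally usable description from the raw bytes. A minimal round-trip sketch (illustrative; dt stands for any committed opal datatype):

    static void ddt_roundtrip_example(opal_datatype_t *dt)
    {
        size_t len = ddt_pack_datatype(dt, NULL);   /* pass 1: measure */
        uint8_t *bytes = malloc(len);
        (void) ddt_pack_datatype(dt, bytes);        /* pass 2: serialize */

        /* receiver side: rebuild a local copy from the raw bytes */
        opal_datatype_t copy;
        (void) ddt_unpack_datatype(&copy, bytes);
        free(bytes);
    }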
+/* Basic implementation: send all buffers without packing,
+   keeping a limited number of requests in flight.
+*/
+static inline int alltoallv_sendrecv_w_direct_for_debugging(
+    void **send_from_addrs,
+    size_t *send_counts,
+    opal_datatype_t **send_types,
+    int jrank_sendto,
+    int ntypes_send,
+    void **recv_to_addrs,
+    size_t *recv_counts,
+    opal_datatype_t **recv_types,
+    int jrank_recvfrom,
+    int ntypes_recv,
+    struct ompi_communicator_t *comm) {
+
+    const int MAX_BUF_COUNT=8;
+    int nreqs = MAX_BUF_COUNT;
+    ompi_request_t *requests[MAX_BUF_COUNT];
+
+    int jfirst_sendreq = nreqs/2 + nreqs%2;
+
+    int jreq;
+    ompi_datatype_t *yuck_ompi_dtype_from_opal;
+    int rc = 0;
+    int jloop;
+
+    int jrecvs_posted = 0;
+    int jrecvs_completed = 0;
+    int jsends_posted = 0;
+    int jsends_completed = 0;
+
+    for (jloop = 0; ; jloop++) {
+        int have_completion;
+
+        if (jsends_completed == ntypes_send && jrecvs_completed == ntypes_recv )
+            break;
+
+        if (jloop < nreqs){
+            jreq = jloop;
+            have_completion = 0;
+        } else {
+            have_completion = 1;
+            ompi_request_wait_any( nreqs, requests, &jreq, MPI_STATUS_IGNORE );
+        }
+        int ii_send_req = jreq >= jfirst_sendreq;
+        if (have_completion) {
+            if (ii_send_req)
+                jsends_completed++;
+            else
+                jrecvs_completed++;
+        }
+
+        requests[jreq] = &ompi_request_null.request;
+        if (ii_send_req && jsends_posted < ntypes_send) {
+            rc = ompi_datatype_create_contiguous( 1, (ompi_datatype_t*)send_types[jsends_posted], &yuck_ompi_dtype_from_opal );
+            ompi_datatype_commit(&yuck_ompi_dtype_from_opal);
+            MCA_PML_CALL(isend
+                (send_from_addrs[jsends_posted], (int)send_counts[jsends_posted], yuck_ompi_dtype_from_opal, jrank_sendto,
+                 MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
+                 comm, &requests[jreq]));
+            ompi_datatype_destroy( &yuck_ompi_dtype_from_opal );
+            jsends_posted++;
+        }
+        if (!ii_send_req && jrecvs_posted < ntypes_recv ) {
+            rc = ompi_datatype_create_contiguous( 1, (ompi_datatype_t*)recv_types[jrecvs_posted], &yuck_ompi_dtype_from_opal );
+            ompi_datatype_commit(&yuck_ompi_dtype_from_opal);
+            MCA_PML_CALL(irecv
+                (recv_to_addrs[jrecvs_posted], (int)recv_counts[jrecvs_posted], yuck_ompi_dtype_from_opal, jrank_recvfrom,
+                 MCA_COLL_BASE_TAG_ALLTOALLV,
+                 comm, &requests[jreq]));
+            ompi_datatype_destroy( &yuck_ompi_dtype_from_opal );
+            jrecvs_posted++;
+        }
+
+        if (rc) { break; }
+    }
+    return rc;
+}
+
+static int alltoallv_sendrecv_w(
+    void **send_from_addrs,
+    size_t *send_counts,
+    opal_datatype_t **send_types,
+    int jrank_sendto,
+    int ntypes_send,
+    void **recv_to_addrs,
+    size_t *recv_counts,
+    opal_datatype_t **recv_types,
+    int jrank_recvfrom,
+    int ntypes_recv,
+    struct ompi_communicator_t *comm) {
+
+    uint32_t iov_count = 1;
+    struct iovec iov;
+
+    const int MAX_BUF_COUNT=8;
+    ompi_request_t *requests[MAX_BUF_COUNT];
+    opal_free_list_item_t *buf_items[MAX_BUF_COUNT];
+
+    size_t buf_len = COLL_HAN_PACKBUF_PAYLOAD_BYTES;
+    int nbufs = MAX_BUF_COUNT;
+    for (int jbuf=0; jbuf<nbufs; jbuf++) {
+        buf_items[jbuf] = opal_free_list_wait(&mca_coll_han_component.pack_buffers);
+    }
+
+    int nreqs = nbufs;
+    int jfirst_sendreq = nreqs/2 + nreqs%2;
+    int jreq;
+    int rc = 0;
+
+    opal_convertor_t send_convertor;
+    opal_convertor_t recv_convertor;
+    OBJ_CONSTRUCT(&send_convertor, opal_convertor_t);
+    OBJ_CONSTRUCT(&recv_convertor, opal_convertor_t);
+
+    /* progress state: current datatype index and bytes left on each side */
+    int jtype_send = -1;
+    int jtype_recv = -1;
+    size_t send_pack_bytes_remaining = 0;
+    size_t recv_convertor_bytes_remaining = 0;
+    size_t nbytes_pack = 0;
+    int nsend_req_pending = 0;
+
+    /* total packed bytes for which receives must still be posted */
+    size_t recv_post_remaining_bytes = 0;
+    for (int jtype=0; jtype<ntypes_recv; jtype++) {
+        size_t type_size;
+        opal_datatype_type_size(recv_types[jtype], &type_size);
+        recv_post_remaining_bytes += type_size * recv_counts[jtype];
+    }
+
+    for (int jloop=0; ; jloop++) {
+        int have_completion;
+
+        int ii_more_sends_to_post = jtype_send < ntypes_send || send_pack_bytes_remaining > 0;
+        int ii_more_sends_to_complete = nsend_req_pending > 0;
+
+        int ii_more_recvs_to_post = recv_post_remaining_bytes > 0;
+        int ii_more_recvs_to_complete = recv_convertor_bytes_remaining > 0 || jtype_recv < ntypes_recv;
+
+        if ( !( ii_more_sends_to_post || ii_more_sends_to_complete ||
+                ii_more_recvs_to_post || ii_more_recvs_to_complete) ) {
+            /* exit. All done! */
+            break;
+        }
+
+        if (jloop >= nreqs) {
+            /* Common case: wait for any send or recv to complete */
+            rc = ompi_request_wait_any(nreqs, requests, &jreq, MPI_STATUS_IGNORE);
+            if (rc != 0 || jreq == MPI_UNDEFINED) {
+                return 1;
+            }
+            have_completion = 1;
+        } else {
+            /* Priming the loop: post sends or recvs while have_completion=0.
+
+               Note: it is possible we have more buffers than data, so the
+               logic below makes sure we don't loop forever or post empty
+               sends. */
+            jreq = jloop;
+            have_completion = 0;
+        }
+        int ii_send_req = jreq >= jfirst_sendreq;
+        char *req_buf = buf_items[jreq]->ptr;
+
+        if (ii_send_req) {
+
+            if (have_completion) {
+                /* a send request has completed */
+                nsend_req_pending--;
+            }
+
+            ssize_t buf_remain = buf_len;
+            while (buf_remain > 0 && (jtype_send < ntypes_send || send_pack_bytes_remaining > 0) ) {
+                if (jtype_send < ntypes_send && send_pack_bytes_remaining==0) {
+                    /* switch to the next datatype and prepare a convertor: */
+                    jtype_send++;
+                    if (jtype_send < ntypes_send) {
+                        opal_convertor_copy_and_prepare_for_send(ompi_mpi_local_convertor,
+                                                                 send_types[jtype_send],
+                                                                 send_counts[jtype_send],
+                                                                 send_from_addrs[jtype_send],
+                                                                 0,
+                                                                 &send_convertor);
+                        opal_convertor_get_packed_size( &send_convertor, &send_pack_bytes_remaining );
+                    }
+                }
+
+                /* pack more data */
+                if (send_pack_bytes_remaining > 0 && buf_remain > 0) {
+                    /* pack into the buffer described by the iov;
+                       the iovec sets the destination of the copy */
+                    size_t start_offset = buf_len - buf_remain;
+                    iov.iov_base = (char*)(req_buf) + start_offset;
+                    iov.iov_len = buf_remain;
+                    iov_count = 1;
+                    opal_convertor_pack(&send_convertor, &iov, &iov_count, &nbytes_pack);
+                    buf_remain -= nbytes_pack;
+                    send_pack_bytes_remaining -= nbytes_pack;
+                }
+            }
+
+            /* post the send */
+            if (buf_len - buf_remain > 0) {
+                MCA_PML_CALL(isend
+                    (req_buf, (buf_len-buf_remain), MPI_PACKED, jrank_sendto,
+                     MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
+                     comm, &requests[jreq]));
+                nsend_req_pending++;
+            } else {
+                requests[jreq] = MPI_REQUEST_NULL;
+            }
+
+        } else { /* recv request */
+            if (have_completion) {
+                /* unpack the data */
+                ssize_t buf_remain = buf_len;
+                size_t buf_converted = 0;
+                while (true) {
+                    if (jtype_recv < ntypes_recv && recv_convertor_bytes_remaining==0) {
+                        /* switch to the next datatype and prepare a convertor: */
+                        jtype_recv++;
+                        if (jtype_recv < ntypes_recv) {
+                            opal_convertor_copy_and_prepare_for_recv(ompi_mpi_local_convertor,
+                                                                     recv_types[jtype_recv],
+                                                                     recv_counts[jtype_recv],
+                                                                     recv_to_addrs[jtype_recv],
+                                                                     0,
+                                                                     &recv_convertor);
+                            opal_convertor_get_packed_size( &recv_convertor, &recv_convertor_bytes_remaining );
+                        }
+                    }
+                    if (jtype_recv == ntypes_recv && recv_convertor_bytes_remaining==0 ) {
+                        /* _all_ receiving work is done! */
+                        buf_remain = 0;
+                    }
+                    if (buf_remain == 0) { break; }
+
+                    /* unpack more data */
+                    if (recv_convertor_bytes_remaining > 0) {
+                        /* unpack from the buffer described by the iov */
+                        iov.iov_base = (char*)(req_buf) + buf_converted;
+                        iov.iov_len = buf_remain;
+                        iov_count = 1;
+                        opal_convertor_unpack(&recv_convertor, &iov, &iov_count, &nbytes_pack);
+
+                        buf_remain -= nbytes_pack;
+                        buf_converted += nbytes_pack;
+                        recv_convertor_bytes_remaining -= nbytes_pack;
+                    }
+                }
+            }
+
+            size_t bytes_to_post = MIN(buf_len, recv_post_remaining_bytes);
+            if (bytes_to_post > 0) {
+                /* post a new recv */
+                MCA_PML_CALL(irecv
+                    (req_buf, bytes_to_post, MPI_PACKED, jrank_recvfrom,
+                     MCA_COLL_BASE_TAG_ALLTOALLV,
+                     comm, &requests[jreq]));
+
+                /* update the count of bytes with posted receives */
+                recv_post_remaining_bytes -= bytes_to_post;
+            } else {
+                requests[jreq] = MPI_REQUEST_NULL;
+            }
+        }
+    }
+
+    OBJ_DESTRUCT(&send_convertor);
+    OBJ_DESTRUCT(&recv_convertor);
+
+    for (int jbuf=0; jbuf<nbufs; jbuf++) {
+        opal_free_list_return(&mca_coll_han_component.pack_buffers, buf_items[jbuf]);
+    }
+    return rc;
+}
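Both directions of alltoallv_sendrecv_w reduce to the same convertor idiom: prepare once per datatype, then repeatedly drain into fixed-size pack buffers. One pack step in isolation (illustrative sketch; addr, type, count, buf, and buf_len stand for the corresponding variables above):

    opal_convertor_t conv;
    OBJ_CONSTRUCT(&conv, opal_convertor_t);
    opal_convertor_copy_and_prepare_for_send(ompi_mpi_local_convertor,
                                             type, count, addr, 0, &conv);
    size_t remaining;
    opal_convertor_get_packed_size(&conv, &remaining);  /* total bytes to pack */

    struct iovec iov = { .iov_base = buf, .iov_len = buf_len };
    uint32_t iov_count = 1;
    size_t packed = 0;
    opal_convertor_pack(&conv, &iov, &iov_count, &packed);
    /* 'packed' bytes now sit in buf; the convertor keeps its position, so
       the next opal_convertor_pack call resumes where this one stopped. */
    OBJ_DESTRUCT(&conv);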
"Yes" : "No"); + } + return rc; +} + +int mca_coll_han_alltoallv_using_smsc( + const void *sbuf, + ompi_count_array_t scounts, + ompi_disp_array_t sdispls, + struct ompi_datatype_t *sdtype, + void* rbuf, + ompi_count_array_t rcounts, + ompi_disp_array_t rdispls, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + + OPAL_OUTPUT_VERBOSE((90, mca_coll_han_component.han_output, + "Entering mca_coll_han_alltoall_using_smsc\n")); + int rc; + + mca_coll_han_module_t *han_module = (mca_coll_han_module_t *)module; + + if (!mca_smsc || !mca_smsc_base_has_feature(MCA_SMSC_FEATURE_CAN_MAP)) { + /* Assume all hosts take this path together :-\ */ + opal_output_verbose(1, mca_coll_han_component.han_output, "in mca_coll_han_alltoallv_using_smsc, " + "but MCA_SMSC_FEATURE_CAN_MAP not available. Disqualifying this alg!\n"); + HAN_UNINSTALL_COLL_API(comm, han_module, alltoallv); + return han_module->previous_alltoallv(sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype, + comm, han_module->previous_alltoallv_module); + } + + /* Create the subcommunicators */ + if( OMPI_SUCCESS != mca_coll_han_comm_create_new(comm, han_module) ) { + opal_output_verbose(1, mca_coll_han_component.han_output, + "han cannot handle alltoallv with this communicator. Fall back on another component\n"); + /* HAN cannot work with this communicator so fallback on all collectives */ + HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module); + return han_module->previous_alltoallv(sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype, + comm, han_module->previous_alltoallv_module); + } + + /* Topo must be initialized to know rank distribution which then is used to + * determine if han can be used */ + mca_coll_han_topo_init(comm, han_module, 2); + if (han_module->are_ppn_imbalanced || !han_module->is_mapbycore){ + opal_output_verbose(1, mca_coll_han_component.han_output, + "han cannot handle alltoallv with this communicator (imbalance/!mapbycore). " + "Fall back on another component\n"); + /* Put back the fallback collective support and call it once. All + * future calls will then be automatically redirected. 
+
+int mca_coll_han_alltoallv_using_smsc(
+    const void *sbuf,
+    ompi_count_array_t scounts,
+    ompi_disp_array_t sdispls,
+    struct ompi_datatype_t *sdtype,
+    void* rbuf,
+    ompi_count_array_t rcounts,
+    ompi_disp_array_t rdispls,
+    struct ompi_datatype_t *rdtype,
+    struct ompi_communicator_t *comm,
+    mca_coll_base_module_t *module)
+{
+    OPAL_OUTPUT_VERBOSE((90, mca_coll_han_component.han_output,
+                         "Entering mca_coll_han_alltoallv_using_smsc\n"));
+    int rc;
+
+    mca_coll_han_module_t *han_module = (mca_coll_han_module_t *)module;
+
+    if (!mca_smsc || !mca_smsc_base_has_feature(MCA_SMSC_FEATURE_CAN_MAP)) {
+        /* Assume all hosts take this path together :-\ */
+        opal_output_verbose(1, mca_coll_han_component.han_output,
+                            "in mca_coll_han_alltoallv_using_smsc, "
+                            "but MCA_SMSC_FEATURE_CAN_MAP is not available. Disqualifying this algorithm!\n");
+        HAN_UNINSTALL_COLL_API(comm, han_module, alltoallv);
+        return han_module->previous_alltoallv(sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype,
+                                              comm, han_module->previous_alltoallv_module);
+    }
+
+    /* Create the subcommunicators */
+    if( OMPI_SUCCESS != mca_coll_han_comm_create_new(comm, han_module) ) {
+        opal_output_verbose(1, mca_coll_han_component.han_output,
+                            "han cannot handle alltoallv with this communicator. Falling back on another component\n");
+        /* HAN cannot work with this communicator, so fall back on all collectives */
+        HAN_LOAD_FALLBACK_COLLECTIVES(comm, han_module);
+        return han_module->previous_alltoallv(sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype,
+                                              comm, han_module->previous_alltoallv_module);
+    }
+
+    /* The topology must be initialized to know the rank distribution, which
+     * is then used to determine whether HAN can be used */
+    mca_coll_han_topo_init(comm, han_module, 2);
+    if (han_module->are_ppn_imbalanced || !han_module->is_mapbycore){
+        opal_output_verbose(1, mca_coll_han_component.han_output,
+                            "han cannot handle alltoallv with this communicator (imbalance/!mapbycore). "
+                            "Falling back on another component\n");
+        /* Put back the fallback collective support and call it once. All
+         * future calls will then be automatically redirected.
+         */
+        HAN_UNINSTALL_COLL_API(comm, han_module, alltoallv);
+        return han_module->previous_alltoallv(sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype,
+                                              comm, han_module->previous_alltoallv_module);
+    }
+
+    int w_rank = ompi_comm_rank(comm);
+    int w_size = ompi_comm_size(comm);
+
+    int use_smsc;
+    rc = decide_to_use_smsc_alg(&use_smsc,
+                                sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype, comm);
+    if (rc != 0) { return rc; }
+    if (!use_smsc) {
+        return han_module->previous_alltoallv(sbuf, scounts, sdispls, sdtype, rbuf, rcounts, rdispls, rdtype,
+                                              comm, han_module->previous_alltoallv_module);
+    }
+
+    ompi_communicator_t *low_comm = han_module->sub_comm[INTRA_NODE];
+    ompi_communicator_t *up_comm = han_module->sub_comm[INTER_NODE];
+
+    /* information about the sub-communicators */
+    int low_rank = ompi_comm_rank(low_comm);
+    int low_size = ompi_comm_size(low_comm);
+    int up_size = ompi_comm_size(up_comm);
+    int up_rank = ompi_comm_rank(up_comm);
+
+    struct gathered_data low_gather_in;
+    struct gathered_data *low_gather_out;
+
+    low_gather_in.stype_serialized_length = ddt_pack_datatype(&sdtype->super, NULL);
+    low_gather_in.rtype_serialized_length = ddt_pack_datatype(&rdtype->super, NULL);
+
+    uint8_t *serialization_buf;
+    size_t serialization_buf_length = low_gather_in.stype_serialized_length +
+                                      low_gather_in.rtype_serialized_length +
+                                      sizeof(struct peer_counts)*w_size;
+
+    /* allocate data */
+    serialization_buf = malloc(serialization_buf_length);
+    low_gather_out = malloc(sizeof(*low_gather_out) * low_size);
+    struct peer_data *peers = malloc(sizeof(*peers) * low_size);
+    opal_datatype_t *peer_send_types = malloc(sizeof(*peer_send_types) * low_size);
+    opal_datatype_t *peer_recv_types = malloc(sizeof(*peer_recv_types) * low_size);
+
+    low_gather_in.serialization_buffer = serialization_buf;
+    low_gather_in.sbuf = (void*)sbuf;   // discard const
+    low_gather_in.rbuf = rbuf;
+
+    low_gather_in.sbuf_length = 0;
+    low_gather_in.rbuf_length = 0;
+    ptrdiff_t sextent;
+    ptrdiff_t rextent;
+    rc = ompi_datatype_type_extent( sdtype, &sextent);
+    rc = ompi_datatype_type_extent( rdtype, &rextent);
+
+    /* calculate the extent of our buffers so that peers can mmap the whole thing */
+    for (int jrankw=0; jrankw<w_size; jrankw++) {
+        size_t sz;
+        sz = (ompi_disp_array_get(sdispls,jrankw) + ompi_count_array_get(scounts,jrankw))*sextent;
+        if (sz > low_gather_in.sbuf_length) {
+            low_gather_in.sbuf_length = sz;
+        }
+        sz = (ompi_disp_array_get(rdispls,jrankw) + ompi_count_array_get(rcounts,jrankw))*rextent;
+        if (sz > low_gather_in.rbuf_length) {
+            low_gather_in.rbuf_length = sz;
+        }
+    }
+
+    /* pack the serialization buffer: first the array of counts */
+    size_t buf_packed = 0;
+    for (int jrankw=0; jrankw<w_size; jrankw++) {
+        struct peer_counts *counts = (struct peer_counts*)&serialization_buf[buf_packed];
+        counts->scount = ompi_count_array_get(scounts,jrankw);
+        counts->rcount = ompi_count_array_get(rcounts,jrankw);
+        counts->sdispl = ompi_disp_array_get(sdispls,jrankw);
+        counts->rdispl = ompi_disp_array_get(rdispls,jrankw);
+        buf_packed += sizeof(*counts);
+    }
+    /* pack the serialization buffer: next the send and recv dtypes */
+    buf_packed += ddt_pack_datatype(&sdtype->super, serialization_buf + buf_packed);
+    buf_packed += ddt_pack_datatype(&rdtype->super, serialization_buf + buf_packed);
+    assert(buf_packed == serialization_buf_length);
+
+    rc = low_comm->c_coll->coll_allgather(&low_gather_in, sizeof(low_gather_in), MPI_BYTE,
+                                          low_gather_out, sizeof(low_gather_in), MPI_BYTE, low_comm,
+                                          low_comm->c_coll->coll_allgather_module);
+    if (rc != 0) {
+        OPAL_OUTPUT_VERBOSE((40, mca_coll_han_component.han_output,
+                             "Allgather failed with %d\n",rc));
+        goto cleanup;
+    }
+
+    /*
+     * In this loop we unpack the data received in the allgather:
+     *  - use SMSC to memory-map the serialization buffer
+     *  - construct datatype objects from the serialization buffer
+     *  - set pointers to read counts and displacements directly from the serialization buffer
+     *  - memory-map the send and recv buffers
+     */
+    for (int jrank=0; jrank<low_size; jrank++) {
+        if (jrank == low_rank) {
+            /* self-exchange: use our own buffers, counts, and types directly */
+            peers[jrank].sbuf = sbuf;
+            peers[jrank].rbuf = rbuf;
+            peers[jrank].counts = (struct peer_counts*)serialization_buf;
+            peers[jrank].recvtype = &rdtype->super;
+            peers[jrank].sendtype = &sdtype->super;
+            peers[jrank].map_ctx[0] = NULL;
+            peers[jrank].map_ctx[1] = NULL;
+            peers[jrank].map_ctx[2] = NULL;
+            continue;
+        }
+        struct gathered_data *gathered = &low_gather_out[jrank];
+        struct ompi_proc_t* ompi_proc = ompi_comm_peer_lookup(low_comm, jrank);
+        mca_smsc_endpoint_t *smsc_ep;
+        smsc_ep = mca_coll_han_get_smsc_endpoint(ompi_proc);
+
+        uint8_t *peer_serialization_buf;
+        size_t peer_serialization_buf_length;
+        peer_serialization_buf_length = w_size * sizeof(struct peer_counts);
+        peer_serialization_buf_length += gathered->rtype_serialized_length;
+        peer_serialization_buf_length += gathered->stype_serialized_length;
+
+        /* mmap the buffers */
+        peers[jrank].map_ctx[0] = mca_smsc->map_peer_region(
+            smsc_ep,
+            MCA_RCACHE_FLAGS_PERSIST,
+            gathered->serialization_buffer,
+            peer_serialization_buf_length,
+            (void**) &peer_serialization_buf );
+        peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
+            smsc_ep,
+            MCA_RCACHE_FLAGS_PERSIST,
+            gathered->sbuf,
+            gathered->sbuf_length,
+            (void**)&peers[jrank].sbuf );
+        peers[jrank].map_ctx[2] = mca_smsc->map_peer_region(
+            smsc_ep,
+            MCA_RCACHE_FLAGS_PERSIST,
+            gathered->rbuf,
+            gathered->rbuf_length,
+            &peers[jrank].rbuf );
+
+        /* point the counts pointer into the mmapped serialization buffer */
+        peers[jrank].counts = (struct peer_counts*)peer_serialization_buf;
+        peer_serialization_buf += w_size * sizeof(struct peer_counts);
+
+        /* unpack the dtypes */
+        peer_serialization_buf += ddt_unpack_datatype( &peer_send_types[jrank], peer_serialization_buf);
+        peer_serialization_buf += ddt_unpack_datatype( &peer_recv_types[jrank], peer_serialization_buf);
+        peers[jrank].sendtype = &peer_send_types[jrank];
+        peers[jrank].recvtype = &peer_recv_types[jrank];
+    }
+
+    void **send_from_addrs = malloc(sizeof(*send_from_addrs)*low_size);
+    void **recv_to_addrs = malloc(sizeof(*recv_to_addrs)*low_size);
+    size_t *send_counts = malloc(sizeof(*send_counts)*low_size);
+    size_t *recv_counts = malloc(sizeof(*recv_counts)*low_size);
+    opal_datatype_t **send_types = malloc(sizeof(*send_types)*low_size);
+    opal_datatype_t **recv_types = malloc(sizeof(*recv_types)*low_size);
+
+    /****
+     * Main exchange loop
+     ****/
+    int nloops = up_size;
+    for (int jloop=0; jloop<nloops; jloop++) {
+        /* pair up with one remote host per round; our single peer on that
+           host has the same on-node rank as us (map-by-core layout) */
+        int up_partner = ring_partner(up_rank, jloop, up_size);
+        int jrank_sendto = up_partner*low_size + low_rank;
+        int jrank_recvfrom = jrank_sendto;
+
+        for (int jlow=0; jlow<low_size; jlow++) {
+            int remote_wrank = up_partner*low_size + jlow;
+
+            /* gather from each local peer the block destined for our partner */
+            ptrdiff_t peer_sextent = peers[jlow].sendtype->ub - peers[jlow].sendtype->lb;
+            send_from_addrs[jlow] = (uint8_t*)peers[jlow].sbuf
+                + peers[jlow].counts[jrank_sendto].sdispl * peer_sextent;
+            send_counts[jlow] = peers[jlow].counts[jrank_sendto].scount;
+            send_types[jlow] = peers[jlow].sendtype;
+
+            recv_to_addrs[jlow] = (uint8_t*)rbuf + ompi_disp_array_get(rdispls,remote_wrank)*rextent;
+            recv_counts[jlow] = ompi_count_array_get(rcounts,remote_wrank);
+            recv_types[jlow] = &(rdtype->super);
+        }
+
+        int ntypes_send = low_size;
+        int ntypes_recv = low_size;
+
+        rc = alltoallv_sendrecv_w(
+            send_from_addrs, send_counts, send_types, jrank_sendto, ntypes_send,
+            recv_to_addrs, recv_counts, recv_types, jrank_recvfrom, ntypes_recv,
+            comm);
+        if (rc != 0) goto cleanup;
+    }
+
+    free(send_from_addrs);
+    free(recv_to_addrs);
+    free(send_counts);
+    free(recv_counts);
+    free(send_types);
+    free(recv_types);
+
+    rc = 0;
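+    /* Hold every on-node rank here: a peer may still be copying out of our
+       exported sbuf/rbuf, and unmapping those regions below while a peer is
+       still reading them would be a use-after-unmap. */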
+    low_comm->c_coll->coll_barrier(low_comm, low_comm->c_coll->coll_barrier_module);
+
+cleanup:
+    for (int jlow=0; jlow<low_size; jlow++) {
+        for (int jbuf=0; jbuf<3; jbuf++) {
+            if (NULL != peers[jlow].map_ctx[jbuf]) {
+                mca_smsc->unmap_peer_region(peers[jlow].map_ctx[jbuf]);
+            }
+        }
+    }
+    free(serialization_buf);
+    free(low_gather_out);
+    free(peers);
+    free(peer_send_types);
+    free(peer_recv_types);
+
+    OPAL_OUTPUT_VERBOSE((40, mca_coll_han_component.han_output,
+                         "Alltoallv complete with %d\n",rc));
+    return rc;
+}
diff --git a/ompi/mca/coll/han/coll_han_component.c b/ompi/mca/coll/han/coll_han_component.c
index bb11b7d5ab7..b3362314602 100644
--- a/ompi/mca/coll/han/coll_han_component.c
+++ b/ompi/mca/coll/han/coll_han_component.c
@@ -118,7 +118,24 @@ static int han_open(void)
         mca_coll_han_component.han_output = ompi_coll_base_framework.framework_output;
     }
 
-
+    OBJ_CONSTRUCT(&mca_coll_han_component.pack_buffers, opal_free_list_t);
+
+    int ret = opal_free_list_init(
+        /* *flist, frag_size, frag_alignment */
+        &mca_coll_han_component.pack_buffers, sizeof(opal_free_list_item_t), 8,
+        /* opal_class_t *frag_class */
+        OBJ_CLASS(opal_free_list_item_t),
+        /* payload_buffer_size, payload_buffer_alignment */
+        COLL_HAN_PACKBUF_PAYLOAD_BYTES, 8,
+        /* num_elements_to_alloc, max_elements_to_alloc, num_elements_per_alloc */
+        0, 32, 8,
+        /* *mpool, rcache_reg_flags, *rcache */
+        NULL, 0, NULL,
+        /* fn_t item_init, void *ctx */
+        NULL, NULL);
+    if (ret != 0) {
+        printf("han: initializing free list got %d\n", ret);
+    }
     return mca_coll_han_init_dynamic_rules();
 }
 
@@ -411,6 +428,26 @@ static int han_register(void)
                                            OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
                                            &cs->han_alltoall_pstages);
 
+    cs->han_alltoallv_low_module = 0;
+    (void) mca_coll_han_query_module_from_mca(c, "alltoallv_lower_module",
+                                              "low level module for alltoallv: 0 tuned, 1 sm",
+                                              OPAL_INFO_LVL_9, &cs->han_alltoallv_low_module,
+                                              &cs->han_op_module_name.alltoallv.han_op_low_module_name);
+    cs->han_alltoallv_smsc_avg_send_limit = 8192;
+    (void) mca_base_component_var_register(c, "alltoallv_smsc_avg_send_limit",
+                                           "The per-rank average send-size limit in bytes above which the smsc-based alltoallv disqualifies itself.",
+                                           MCA_BASE_VAR_TYPE_INT64_T, NULL, 0, 0,
+                                           OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
+                                           &cs->han_alltoallv_smsc_avg_send_limit);
+    cs->han_alltoallv_smsc_noncontig_activation_limit = 0.10;
+    (void) mca_base_component_var_register(c, "alltoallv_smsc_noncontig_limit",
+                                           "The fraction (0.00-1.00) of ranks in the communicator that have "
+                                           "strided or otherwise non-contiguous data buffers.  Above this limit "
+                                           "the smsc-based alltoallv ignores avg_send_limit and always remains active.",
+                                           MCA_BASE_VAR_TYPE_DOUBLE, NULL, 0, 0,
+                                           OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
+                                           &cs->han_alltoallv_smsc_noncontig_activation_limit);
+
     cs->han_reproducible = 0;
     (void) mca_base_component_var_register(c, "reproducible",
                                            "whether we need reproducible results "
diff --git a/ompi/mca/coll/han/coll_han_dynamic.c b/ompi/mca/coll/han/coll_han_dynamic.c
index 69f25b40757..3d483b35aa9 100644
--- a/ompi/mca/coll/han/coll_han_dynamic.c
+++ b/ompi/mca/coll/han/coll_han_dynamic.c
@@ -42,6 +42,7 @@ bool mca_coll_han_is_coll_dynamic_implemented(COLLTYPE_T coll_id)
     case ALLGATHERV:
     case ALLREDUCE:
     case ALLTOALL:
+    case ALLTOALLV:
     case BARRIER:
     case BCAST:
     case GATHER:
@@ -1635,3 +1636,114 @@ mca_coll_han_alltoall_intra_dynamic(const void *sbuf, size_t scount,
                                   comm, sub_module);
 }
 
+
+/*
+ * alltoallv selector:
+ * On a sub-communicator, checks the stored rules to find the module to use.
+ * On the global communicator, calls the HAN collective implementation, or
+ * calls the correct module if the fallback mechanism is activated.
+ */
+int
+mca_coll_han_alltoallv_intra_dynamic(
+    ALLTOALLV_BASE_ARGS,
+    mca_coll_base_module_t *module)
+{
+    mca_coll_han_module_t *han_module = (mca_coll_han_module_t*) module;
+    TOPO_LVL_T topo_lvl = han_module->topologic_level;
+    mca_coll_base_module_alltoallv_fn_t alltoallv;
+    mca_coll_base_module_t *sub_module;
+    int rank, verbosity = 0;
+
+    if (!han_module->enabled) {
+        return han_module->previous_alltoallv(ALLTOALLV_BASE_ARG_NAMES,
+                                              han_module->previous_alltoallv_module);
+    }
+
+    /* v collectives do not support message-size based dynamic rules */
+    sub_module = get_module(ALLTOALLV,
+                            MCA_COLL_HAN_ANY_MESSAGE_SIZE,
+                            comm,
+                            han_module);
+
+    /* First errors are always printed by rank 0 */
+    rank = ompi_comm_rank(comm);
+    if( (0 == rank) && (han_module->dynamic_errors < mca_coll_han_component.max_dynamic_errors) ) {
+        verbosity = 30;
+    }
+
+    if(NULL == sub_module) {
+        /*
+         * No valid collective module from the dynamic rules,
+         * nor from the MCA parameter
+         */
+        han_module->dynamic_errors++;
+        opal_output_verbose(verbosity, mca_coll_han_component.han_output,
+                            "coll:han:mca_coll_han_alltoallv_intra_dynamic "
+                            "HAN did not find any valid module for collective %d (%s) "
+                            "with topological level %d (%s) on communicator (%s/%s). "
+                            "Please check dynamic file/mca parameters\n",
+                            ALLTOALLV, mca_coll_base_colltype_to_str(ALLTOALLV),
+                            topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl),
+                            ompi_comm_print_cid(comm), comm->c_name);
+        OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
+                             "HAN/ALLTOALLV: No module found for the sub-communicator. "
+                             "Falling back to another component\n"));
+        alltoallv = han_module->previous_alltoallv;
+        sub_module = han_module->previous_alltoallv_module;
+    } else if (NULL == sub_module->coll_alltoallv) {
+        /*
+         * No valid collective from the dynamic rules,
+         * nor from the MCA parameter
+         */
+        han_module->dynamic_errors++;
+        opal_output_verbose(verbosity, mca_coll_han_component.han_output,
+                            "coll:han:mca_coll_han_alltoallv_intra_dynamic "
+                            "HAN found a valid module for collective %d (%s) "
+                            "with topological level %d (%s) on communicator (%s/%s), "
+                            "but this module cannot handle this collective. "
" + "Please check dynamic file/mca parameters\n", + ALLTOALLV, mca_coll_base_colltype_to_str(ALLTOALLV), + topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl), + ompi_comm_print_cid(comm), comm->c_name); + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "HAN/ALLTOALLV: the module found for the sub-" + "communicator cannot handle the ALLTOALLV operation. " + "Falling back to another component\n")); + alltoallv = han_module->previous_alltoallv; + sub_module = han_module->previous_alltoallv_module; + } else if (GLOBAL_COMMUNICATOR == topo_lvl && sub_module == module) { + /* + * No fallback mechanism activated for this configuration + * sub_module is valid + * sub_module->coll_alltoallv is valid and point to this function + * Call han topological collective algorithm + */ + int algorithm_id = get_algorithm(ALLTOALLV, + MCA_COLL_HAN_ANY_MESSAGE_SIZE, + comm, + han_module); + alltoallv = (mca_coll_base_module_alltoallv_fn_t)mca_coll_han_algorithm_id_to_fn(ALLTOALLV, algorithm_id); + if (NULL == alltoallv) { /* default behaviour */ + alltoallv = mca_coll_han_alltoallv_using_smsc; + } + } else { + /* + * If we get here: + * sub_module is valid + * sub_module->coll_alltoallv is valid + * They points to the collective to use, according to the dynamic rules + * Selector's job is done, call the collective + */ + alltoallv = sub_module->coll_alltoallv; + } + + /* + * If we get here: + * sub_module is valid + * sub_module->coll_alltoallv is valid + * They points to the collective to use, according to the dynamic rules + * Selector's job is done, call the collective + */ + return alltoallv(ALLTOALLV_BASE_ARG_NAMES, sub_module); +} diff --git a/ompi/mca/coll/han/coll_han_module.c b/ompi/mca/coll/han/coll_han_module.c index cf2c84e27fd..df6f0d08c2e 100644 --- a/ompi/mca/coll/han/coll_han_module.c +++ b/ompi/mca/coll/han/coll_han_module.c @@ -49,6 +49,7 @@ static int mca_coll_han_module_disable(mca_coll_base_module_t * module, static void han_module_clear(mca_coll_han_module_t *han_module) { CLEAN_PREV_COLL(han_module, alltoall); + CLEAN_PREV_COLL(han_module, alltoallv); CLEAN_PREV_COLL(han_module, allgather); CLEAN_PREV_COLL(han_module, allgatherv); CLEAN_PREV_COLL(han_module, allreduce); @@ -235,7 +236,7 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) han_module->super.coll_module_disable = mca_coll_han_module_disable; han_module->super.coll_alltoall = mca_coll_han_alltoall_intra_dynamic; - han_module->super.coll_alltoallv = NULL; + han_module->super.coll_alltoallv = mca_coll_han_alltoallv_intra_dynamic; han_module->super.coll_alltoallw = NULL; han_module->super.coll_exscan = NULL; han_module->super.coll_reduce_scatter = NULL; @@ -297,6 +298,7 @@ mca_coll_han_module_enable(mca_coll_base_module_t * module, mca_coll_han_module_t * han_module = (mca_coll_han_module_t*) module; HAN_INSTALL_COLL_API(comm, han_module, alltoall); + HAN_INSTALL_COLL_API(comm, han_module, alltoallv); HAN_INSTALL_COLL_API(comm, han_module, allgather); HAN_INSTALL_COLL_API(comm, han_module, allgatherv); HAN_INSTALL_COLL_API(comm, han_module, allreduce); @@ -325,6 +327,7 @@ mca_coll_han_module_disable(mca_coll_base_module_t * module, mca_coll_han_module_t * han_module = (mca_coll_han_module_t *) module; HAN_UNINSTALL_COLL_API(comm, han_module, alltoall); + HAN_UNINSTALL_COLL_API(comm, han_module, alltoallv); HAN_UNINSTALL_COLL_API(comm, han_module, allgather); HAN_UNINSTALL_COLL_API(comm, han_module, allgatherv); HAN_UNINSTALL_COLL_API(comm, han_module, allreduce); diff --git 
diff --git a/ompi/mca/coll/han/coll_han_subcomms.c b/ompi/mca/coll/han/coll_han_subcomms.c
index 71e01a4bba3..d3327df9906 100644
--- a/ompi/mca/coll/han/coll_han_subcomms.c
+++ b/ompi/mca/coll/han/coll_han_subcomms.c
@@ -79,6 +79,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm,
      * Gather + Bcast may be called by the allgather implementation
      */
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, alltoall);
+    HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, alltoallv);
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, allgatherv);
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, allgather);
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, allreduce);
@@ -111,6 +112,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm,
         han_module->enabled = false;  /* entire module set to pass-through from now on */
         /* restore saved collectives */
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoall);
+        HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoallv);
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv);
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather);
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce);
@@ -189,6 +191,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm,
 
     /* Restore the saved collectives */
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoall);
+    HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoallv);
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv);
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather);
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce);
@@ -248,6 +251,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm,
      * Gather + Bcast may be called by the allgather implementation
      */
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, alltoall);
+    HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, alltoallv);
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, allgatherv);
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, allgather);
     HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, allreduce);
@@ -275,6 +279,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm,
     if( local_procs == 1 ) {
         /* restore saved collectives */
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoall);
+        HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoallv);
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv);
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather);
         HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce);
@@ -370,6 +375,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm,
 
     /* Reset the saved collectives to point back to HAN */
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoall);
+    HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, alltoallv);
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgatherv);
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allgather);
     HAN_SUBCOM_RESTORE_COLLECTIVE(fallbacks, comm, han_module, allreduce);
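Once built, the new path can be exercised and tuned from the command line; the parameter names follow the registrations above with the usual coll_han prefix (the exact invocation and benchmark are illustrative):

    mpirun -n 64 \
        --mca coll_han_priority 100 \
        --mca coll_han_alltoallv_smsc_avg_send_limit 8192 \
        --mca coll_han_alltoallv_smsc_noncontig_limit 0.10 \
        ./osu_alltoallv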