coll/han/alltoallv: Fix for when types have negative LB
Previously this function considered only displ, count, and extent.
Since the function uses XPMEM to explicitly expose memory regions, we
must also be aware of types that have negative lower bounds and might
access data _before_ the user-provided pointer.

This change more accurately computes the true upper and lower bounds of
all memory accesses, both to ensure we don't try to map regions of
memory that may not be in our VM page table, and to ensure we expose
all memory that will be accessed.
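
As an illustration of the kind of datatype involved (a hypothetical reproducer
built from standard MPI calls, not taken from this change), a type whose only
data block sits at a negative displacement makes every element's data live
before the pointer the user passes:

    /* hypothetical reproducer: one int placed 4 bytes *before* the buffer
       pointer, so the type's true lower bound is negative */
    int blocklen = 1;
    MPI_Aint disp = -4;
    MPI_Datatype neg_lb_type;
    MPI_Type_create_hindexed(1, &blocklen, &disp, MPI_INT, &neg_lb_type);
    MPI_Type_commit(&neg_lb_type);
    /* an XPMEM mapping that starts exactly at sbuf would miss the bytes at
       sbuf-4 .. sbuf-1 that a send with this type reads */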

Signed-off-by: Luke Robison <[email protected]>
lrbison committed Sep 18, 2024
1 parent a50519c commit 4fafbfe
159 changes: 125 additions & 34 deletions ompi/mca/coll/han/coll_han_alltoallv.c
@@ -78,11 +78,74 @@ struct gathered_data {
size_t rtype_serialized_length;
size_t sbuf_length;
size_t rbuf_length;
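/* lowest byte offset (possibly negative) that local sends/recvs will touch,
   relative to sbuf/rbuf; peers begin their mapping at this offset */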
ssize_t sbuf_lb;
ssize_t rbuf_lb;
void *sbuf;
void *rbuf;
void *serialization_buffer;
};

/**
Given a count, displ, and type, compute the true LB and UB of all data
accessed for those count and displ arguments.
Note the byte at UB is not accessed, so full_true_extent is UB-LB.

Consider the most difficult case: a resized type with non-zero LB,
extent != true_extent, and true_LB != LB. In the figure below:
   X represents the 0-point of the user's `buf`
   x represents data accessed by the type
   - represents data spanned by the type when count > 1
   . represents data not accessed or spanned by the type.

      + LB = -5
      |  + true_LB = -2
      |  | + buf (0)
      |  | |
   ...---xxXxxx----...
         |<-->|          true_extent = 6
      |<--------->|      extent = 13

When there are 2 items, the full true extent is:
   ...---xxXxxx-------xxxxxx----...
         |<--------------->|        full true extent = 19, i.e. extent*(n-1) + true_extent
      |<---------------------->|    full extent      = 26, i.e. extent*n
*/
static int han_alltoallv_dtype_get_many_true_lb_ub(
ompi_datatype_t *dtype,
ptrdiff_t count,
ptrdiff_t displ,
ptrdiff_t *full_true_lb,
ptrdiff_t *full_true_ub ) {

ptrdiff_t extent, true_extent, full_true_extent;
ptrdiff_t lb, true_lb;
int rc;

/* note: full_true_lb and full_true_ub are undefined when count == 0!
In that case, set them both to 0. */
*full_true_lb = 0;
*full_true_ub = 0;
if (count == 0) {
return 0;
}

rc = ompi_datatype_get_true_extent( dtype, &true_lb, &true_extent);
if (rc) { return rc; }
rc = ompi_datatype_get_extent( dtype, &lb, &extent);
if (rc) { return rc; }

/* extent of data */
full_true_extent = extent*MAX(0,count-1) + true_extent;
/* for displ, only extent matters (not true_extent)*/
ptrdiff_t displ_bytes = displ * extent;

/* now compute full true LB/UB including displ. */
*full_true_lb = true_lb + displ_bytes;
*full_true_ub = *full_true_lb + full_true_extent;

return 0;
}
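
As a sanity check (hypothetical test code: it assumes it sits in this same file
so the static helper is visible, and relies on Open MPI's MPI_Datatype being an
ompi_datatype_t*), the resized type from the figure above can be built and fed
through the helper:

    /* build the figure's type: 6 data bytes at offsets [-2, 4), i.e.
       true_LB = -2 and true_extent = 6, then resize to LB = -5, extent = 13 */
    int blocklen = 6;
    MPI_Aint disp = -2;
    MPI_Datatype tmp, figure_type;
    MPI_Type_create_hindexed(1, &blocklen, &disp, MPI_BYTE, &tmp);
    MPI_Type_create_resized(tmp, /*lb=*/-5, /*extent=*/13, &figure_type);
    MPI_Type_commit(&figure_type);

    ptrdiff_t flb, fub;
    han_alltoallv_dtype_get_many_true_lb_ub(figure_type, /*count=*/2, /*displ=*/0, &flb, &fub);
    /* expected: flb == -2 and fub == 17, i.e. full_true_extent = 13*(2-1) + 6 = 19 */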

/* Serialize the datatype into the buffer and return buffer length.
If buf is NULL, just return the length of the required buffer. */
static size_t ddt_pack_datatype(opal_datatype_t* type, uint8_t* buf)
@@ -142,7 +205,13 @@ static size_t ddt_unpack_datatype(opal_datatype_t* type, uint8_t* buf)
return length;
}

/* basic implementation: send all buffers without packing keeping a limited number in flight. */
/* Simple implementation: send all buffers without packing, but still keeping a
limited number in flight.
Note: CMA on XPMEM-mapped buffers does not work. If the low-level network
provider attempts to use CMA to implement send/recv, then errors will
occur!
*/
static inline int alltoallv_sendrecv_w_direct_for_debugging(
void **send_from_addrs,
size_t *send_counts,
@@ -185,6 +254,7 @@ static inline int alltoallv_sendrecv_w_direct_for_debugging(
} else {
have_completion = 1;
rc = ompi_request_wait_any( nreqs, requests, &jreq, MPI_STATUS_IGNORE );
if (rc) break;
}
int ii_send_req = jreq >= jfirst_sendreq;
if (have_completion) {
@@ -725,28 +795,49 @@ int mca_coll_han_alltoallv_using_smsc(
opal_datatype_t *peer_recv_types = malloc(sizeof(*peer_recv_types) * low_size);

low_gather_in.serialization_buffer = serialization_buf;
low_gather_in.sbuf = (void*)sbuf; // discard const
low_gather_in.sbuf = (void*)sbuf; // cast to discard the const
low_gather_in.rbuf = rbuf;

low_gather_in.sbuf_length = 0;
low_gather_in.rbuf_length = 0;
ptrdiff_t sextent;
ptrdiff_t rextent;
rc = ompi_datatype_type_extent( sdtype, &sextent);
rc = ompi_datatype_type_extent( rdtype, &rextent);
ptrdiff_t r_extent, r_lb;
rc = ompi_datatype_get_extent( rdtype, &r_lb, &r_extent);

/* calculate the extent of our buffers so that peers can mmap the whole thing */
ssize_t min_send_lb = SSIZE_MAX;
ssize_t max_send_ub = -SSIZE_MAX-1;
ssize_t min_recv_lb = SSIZE_MAX;
ssize_t max_recv_ub = -SSIZE_MAX-1;
/* calculate the overall lower and upper bounds of our buffers so that peers can map the whole region. */
for (int jrankw=0; jrankw<w_size; jrankw++) {
size_t sz;
sz = (ompi_disp_array_get(sdispls,jrankw) + ompi_count_array_get(scounts,jrankw))*sextent;
if (sz > low_gather_in.sbuf_length) {
low_gather_in.sbuf_length = sz;
ptrdiff_t displ;
ptrdiff_t count;
ptrdiff_t send_lb, send_ub, recv_lb, recv_ub;

count = ompi_count_array_get(scounts,jrankw);
displ = ompi_disp_array_get(sdispls,jrankw);
if (count > 0) {
han_alltoallv_dtype_get_many_true_lb_ub( sdtype, count, displ, &send_lb, &send_ub);
min_send_lb = MIN( min_send_lb, send_lb );
max_send_ub = MAX( max_send_ub, send_ub );
}
sz = (ompi_disp_array_get(rdispls,jrankw) + ompi_count_array_get(rcounts,jrankw))*rextent;
if (sz > low_gather_in.rbuf_length) {
low_gather_in.rbuf_length = sz;

count = ompi_count_array_get(rcounts,jrankw);
displ = ompi_disp_array_get(rdispls,jrankw);
if (count > 0) {
han_alltoallv_dtype_get_many_true_lb_ub( rdtype, count, displ, &recv_lb, &recv_ub);
min_recv_lb = MIN( min_recv_lb, recv_lb );
max_recv_ub = MAX( max_recv_ub, recv_ub );
}
}
low_gather_in.sbuf_length = 0;
if (max_send_ub > min_send_lb) {
low_gather_in.sbuf_length = max_send_ub - min_send_lb;
low_gather_in.sbuf_lb = min_send_lb;
}

low_gather_in.rbuf_length = 0;
if (max_recv_ub > min_recv_lb) {
low_gather_in.rbuf_length = max_recv_ub - min_recv_lb;
low_gather_in.rbuf_lb = min_recv_lb;
}
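
/* Worked example (hypothetical numbers): with the type from the doc comment
   above (LB=-5, true_LB=-2, extent=13, true_extent=6), scounts = {2, 1} and
   sdispls = {0, 4}:
     dest 0: lb..ub = -2 .. 17   (displ 0, count 2)
     dest 1: lb..ub = 50 .. 56   (displ 4*13 = 52, plus true bounds -2 .. +4)
   so sbuf_lb = -2 and sbuf_length = 56 - (-2) = 58; the peer must map starting
   2 bytes before sbuf to cover every byte our sends may read. */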

/* pack the serialization buffer: first the array of counts */
size_t buf_packed = 0;
@@ -782,18 +873,21 @@
*/
for (int jrank=0; jrank<low_size; jrank++) {
void *tmp_ptr;
peers[jrank].map_ctx[0] = NULL;
peers[jrank].map_ctx[1] = NULL;
peers[jrank].map_ctx[2] = NULL;

if (jrank == low_rank) {
/* special case for ourself */
peers[jrank].counts = (struct peer_counts *)serialization_buf;
peers[jrank].sbuf = sbuf;
peers[jrank].rbuf = rbuf;
peers[jrank].recvtype = &rdtype->super;
peers[jrank].sendtype = &sdtype->super;
peers[jrank].map_ctx[0] = NULL;
peers[jrank].map_ctx[1] = NULL;
peers[jrank].map_ctx[2] = NULL;
continue;
}

struct gathered_data *gathered = &low_gather_out[jrank];
struct ompi_proc_t* ompi_proc = ompi_comm_peer_lookup(low_comm, jrank);
mca_smsc_endpoint_t *smsc_ep;
@@ -812,18 +906,15 @@
gathered->serialization_buffer,
peer_serialization_buf_length,
(void**) &peer_serialization_buf );
peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
smsc_ep,
MCA_RCACHE_FLAGS_PERSIST,
gathered->sbuf,
gathered->sbuf_length,
(void**)&peers[jrank].sbuf );
peers[jrank].map_ctx[2] = mca_smsc->map_peer_region(
smsc_ep,
MCA_RCACHE_FLAGS_PERSIST,
gathered->rbuf,
gathered->rbuf_length,
&peers[jrank].rbuf );
if (gathered->sbuf_length > 0) {
peers[jrank].map_ctx[1] = mca_smsc->map_peer_region(
smsc_ep,
MCA_RCACHE_FLAGS_PERSIST,
(char*)gathered->sbuf + gathered->sbuf_lb,
gathered->sbuf_length,
&tmp_ptr );
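/* tmp_ptr maps the peer's (sbuf + sbuf_lb); subtract sbuf_lb again so
   peers[jrank].sbuf lines up with the peer's original sbuf and the
   per-destination displacements can be applied unchanged */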
peers[jrank].sbuf = (char*)tmp_ptr - gathered->sbuf_lb;
}

/* point the counts pointer into the mmapped serialization buffer */
peers[jrank].counts = (struct peer_counts*)peer_serialization_buf;
@@ -867,10 +958,10 @@

send_from_addrs[jlow] = from_addr;
send_counts[jlow] = peers[jlow].counts[jrank_sendto].scount;
send_types[jlow] = peers[jlow].sendtype;
// send_types[jlow] = &(sdtype->super);
send_types[jlow] = peers[jlow].sendtype;


recv_to_addrs[jlow] = (uint8_t*)rbuf + ompi_disp_array_get(rdispls,remote_wrank)*rextent;
recv_to_addrs[jlow] = (uint8_t*)rbuf + ompi_disp_array_get(rdispls,remote_wrank)*r_extent;
recv_counts[jlow] = ompi_count_array_get(rcounts,remote_wrank);
recv_types[jlow] = &(rdtype->super);
}
