diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py
index 8b4987c8bb..b176dfbd8c 100644
--- a/devito/mpi/routines.py
+++ b/devito/mpi/routines.py
@@ -436,23 +436,7 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
 
         fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
 
-        # Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
-        # `ofs` are symbolic objects. This mapper tells what data values should be
-        # sent (OWNED) or received (HALO) given dimension and side
-        mapper = {}
-        for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
-            if d0 in fixed:
-                continue
-            sizes = []
-            ofs = []
-            for d1 in f.dimensions:
-                if d1 in fixed:
-                    ofs.append(fixed[d1])
-                else:
-                    meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
-                    ofs.append(meta.offset)
-                    sizes.append(meta.size)
-            mapper[(d0, side, region)] = (sizes, ofs)
+        mapper = self._make_basic_mapper(f, fixed)
 
         body = []
         for d in f.dimensions:
@@ -484,6 +468,29 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
 
         return HaloUpdate('haloupdate%s' % key, iet, parameters)
 
+    def _make_basic_mapper(self, f, fixed):
+        """
+        Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and
+        `ofs` are symbolic objects. This mapper tells what data values should be
+        sent (OWNED) or received (HALO) given dimension and side
+        """
+        mapper = {}
+        for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
+            if d0 in fixed:
+                continue
+            sizes = []
+            ofs = []
+            for d1 in f.dimensions:
+                if d1 in fixed:
+                    ofs.append(fixed[d1])
+                else:
+                    meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side)
+                    ofs.append(meta.offset)
+                    sizes.append(meta.size)
+            mapper[(d0, side, region)] = (sizes, ofs)
+
+        return mapper
+
     def _call_haloupdate(self, name, f, hse, *args):
         comm = f.grid.distributor._obj_comm
         nb = f.grid.distributor._obj_neighborhood
@@ -527,6 +534,121 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits):
         return List(body=body)
 
 
+class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder):
+
+    """
+    A BasicHaloExchangeBuilder making use of pre-allocated buffers for
+    message size.
+
+    Generates:
+
+        haloupdate()
+        compute()
+    """
+
+    def _make_msg(self, f, hse, key):
+        # Pass the fixed mapper e.g. {t: otime}
+        fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
+
+        return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed)
+
+    def _make_sendrecv(self, f, hse, key, msg=None):
+        cast = cast_mapper[(f.c0.dtype, '*')]
+        comm = f.grid.distributor._obj_comm
+
+        bufg = FieldFromPointer(msg._C_field_bufg, msg)
+        bufs = FieldFromPointer(msg._C_field_bufs, msg)
+
+        ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
+        ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]
+
+        fromrank = Symbol(name='fromrank')
+        torank = Symbol(name='torank')
+
+        sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
+                 for i in range(len(f._dist_dimensions))]
+
+        arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg
+        gather = Gather('gather%s' % key, arguments)
+        # The `gather` is unnecessary if sending to MPI.PROC_NULL
+        gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
+
+        arguments = [cast(bufs)] + sizes + list(f.handles) + ofss
+        scatter = Scatter('scatter%s' % key, arguments)
+        # The `scatter` must be guarded as we must not alter the halo values along
+        # the domain boundary, where the sender is actually MPI.PROC_NULL
+        scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)
+
+        count = reduce(mul, sizes, 1)*dtype_len(f.dtype)
+        rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
+        rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
+        recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
+                          fromrank, Integer(13), comm, rrecv])
+        send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
+                          torank, Integer(13), comm, rsend])
+
+        waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
+        waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])
+
+        iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
+
+        parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])
+
+        return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs)
+
+    def _call_sendrecv(self, name, *args, msg=None, haloid=None):
+        # Drop `sizes` as this HaloExchangeBuilder conveys them through `msg`
+        f, _, ofsg, ofss, fromrank, torank, comm = args
+        msg = Byref(IndexedPointer(msg, haloid))
+        return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg])
+
+    def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
+        distributor = f.grid.distributor
+        nb = distributor._obj_neighborhood
+        comm = distributor._obj_comm
+
+        fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}
+
+        mapper = self._make_basic_mapper(f, fixed)
+
+        body = []
+        for d in f.dimensions:
+            if d in fixed:
+                continue
+
+            name = ''.join('r' if i is d else 'c' for i in distributor.dimensions)
+            rpeer = FieldFromPointer(name, nb)
+            name = ''.join('l' if i is d else 'c' for i in distributor.dimensions)
+            lpeer = FieldFromPointer(name, nb)
+
+            if (d, LEFT) in hse.halos:
+                # Sending to left, receiving from right
+                lsizes, lofs = mapper[(d, LEFT, OWNED)]
+                rsizes, rofs = mapper[(d, RIGHT, HALO)]
+                args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm]
+                body.append(self._call_sendrecv(sendrecv.name, *args, haloid=len(body),
+                                                **kwargs))
+
+            if (d, RIGHT) in hse.halos:
+                # Sending to right, receiving from left
+                rsizes, rofs = mapper[(d, RIGHT, OWNED)]
+                lsizes, lofs = mapper[(d, LEFT, HALO)]
+                args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm]
+                body.append(self._call_sendrecv(sendrecv.name, *args, haloid=len(body),
+                                                **kwargs))
+
+        iet = List(body=body)
+
+        parameters = list(f.handles) + [comm, nb] + list(fixed.values()) + [kwargs['msg']]
+
+        return HaloUpdate('haloupdate%s' % key, iet, parameters)
+
+    def _call_haloupdate(self, name, f, hse, msg):
+        call = super()._call_haloupdate(name, f, hse)
+        call = call._rebuild(arguments=call.arguments + (msg,))
+        return call
+
+
 class DiagHaloExchangeBuilder(BasicHaloExchangeBuilder):
 
     """
@@ -1003,6 +1125,7 @@ def _call_poke(self, poke):
 
 mpi_registry = {
     'basic': BasicHaloExchangeBuilder,
+    'basic2': Basic2HaloExchangeBuilder,
     'diag': DiagHaloExchangeBuilder,
     'diag2': Diag2HaloExchangeBuilder,
     'overlap': OverlapHaloExchangeBuilder,
@@ -1112,7 +1235,7 @@ class MPIRequestObject(LocalObject):
     dtype = type('MPI_Request', (c_void_p,), {})
 
 
-class MPIMsg(CompositeObject):
+class MPIMsgBase(CompositeObject):
 
     _C_field_bufs = 'bufs'
     _C_field_bufg = 'bufg'
@@ -1135,17 +1258,6 @@ class MPIMsg(CompositeObject):
 
     __rargs__ = ('name', 'target', 'halos')
 
-    def __init__(self, name, target, halos):
-        self._target = target
-        self._halos = halos
-
-        super().__init__(name, 'msg', self.fields)
-
-        # Required for buffer allocation/deallocation before/after jumping/returning
-        # to/from C-land
-        self._allocator = None
-        self._memfree_args = []
-
     def __del__(self):
         self._C_memfree()
 
@@ -1184,6 +1296,16 @@ def _as_number(self, v, args):
             assert args is not None
             return int(subs_op_args(v, args))
 
+    def _allocate_buffers(self, f, shape, entry):
+        entry.sizes = (c_int*len(shape))(*shape)
+        size = reduce(mul, shape)*dtype_len(self.target.dtype)
+        ctype = dtype_to_ctype(f.dtype)
+        entry.bufg, bufg_memfree_args = self._allocator._alloc_C_libcall(size, ctype)
+        entry.bufs, bufs_memfree_args = self._allocator._alloc_C_libcall(size, ctype)
+        # The `memfree_args` will be used to deallocate the buffer upon
+        # returning from C-land
+        self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
+
     def _arg_defaults(self, allocator, alias, args=None):
         # Lazy initialization if `allocator` is necessary as the `allocator`
         # type isn't really known until an Operator is constructed
@@ -1201,17 +1323,9 @@ def _arg_defaults(self, allocator, alias, args=None):
                 except AttributeError:
                     assert side is CENTER
                     shape.append(self._as_number(f._size_domain[dim], args))
-            entry.sizes = (c_int*len(shape))(*shape)
 
             # Allocate the send/recv buffers
-            size = reduce(mul, shape)*dtype_len(self.target.dtype)
-            ctype = dtype_to_ctype(f.dtype)
-            entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype)
-            entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype)
-
-            # The `memfree_args` will be used to deallocate the buffer upon
-            # returning from C-land
-            self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args])
+            self._allocate_buffers(f, shape, entry)
 
         return {self.name: self.value}
 
@@ -1232,6 +1346,108 @@ def _arg_apply(self, *args, **kwargs):
         self._C_memfree()
 
 
+class MPIMsg(MPIMsgBase):
+
+    def __init__(self, name, target, halos):
+        self._target = target
+        self._halos = halos
+
+        super().__init__(name, 'msg', self.fields)
+
+        # Required for buffer allocation/deallocation before/after jumping/returning
+        # to/from C-land
+        self._allocator = None
+        self._memfree_args = []
+
+    def _arg_defaults(self, allocator, alias, args=None):
+        # Lazy initialization if `allocator` is necessary as the `allocator`
+        # type isn't really known until an Operator is constructed
+        self._allocator = allocator
+
+        f = alias or self.target.c0
+        for i, halo in enumerate(self.halos):
+            entry = self.value[i]
+
+            # Buffer shape for this peer
+            shape = []
+            for dim, side in zip(*halo):
+                try:
+                    shape.append(getattr(f._size_owned[dim], side.name))
+                except AttributeError:
+                    assert side is CENTER
+                    shape.append(self._as_number(f._size_domain[dim], args))
+
+            # Allocate the send/recv buffers
+            self._allocate_buffers(f, shape, entry)
+
+        return {self.name: self.value}
+
+
+class MPIMsgBasic2(MPIMsgBase):
+
+    def __init__(self, name, target, halos, fixed=None):
+        self._target = target
+        self._halos = halos
+
+        super().__init__(name, 'msg', self.fields)
+
+        # Required for buffer allocation/deallocation before/after jumping/returning
+        # to/from C-land
+        self._fixed = fixed
+        self._allocator = None
+        self._memfree_args = []
+
+    def _arg_defaults(self, allocator, alias, args=None):
+        # Lazy initialization if `allocator` is necessary as the `allocator`
+        # type isn't really known until an Operator is constructed
+        self._allocator = allocator
+
+        f = alias or self.target.c0
+
+        fixed = self._fixed
+
+        # Build a mapper `(dim, side, region) -> (size)` for `f`.
+        mapper = {}
+        for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)):
+            if d0 in fixed:
+                continue
+            sizes = []
+            for d1 in f.dimensions:
+                if d1 in fixed:
+                    continue
+                if d0 is d1:
+                    if region is OWNED:
+                        sizes.append(getattr(f._size_owned[d0], side.name))
+                    elif region is HALO:
+                        sizes.append(getattr(f._size_halo[d0], side.name))
+                else:
+                    sizes.append(self._as_number(f._size_nopad[d1], args))
+            mapper[(d0, side, region)] = sizes
+
+        i = 0
+        for d in f.dimensions:
+            if d in fixed:
+                continue
+
+            if (d, LEFT) in self.halos:
+                entry = self.value[i]
+                i = i + 1
+                # Sending to left, receiving from right
+                shape = mapper[(d, LEFT, OWNED)]
+                # Allocate the send/recv buffers
+                self._allocate_buffers(f, shape, entry)
+
+            if (d, RIGHT) in self.halos:
+                entry = self.value[i]
+                i = i + 1
+                # Sending to right, receiving from left
+                shape = mapper[(d, RIGHT, OWNED)]
+                # Allocate the send/recv buffers
+                self._allocate_buffers(f, shape, entry)
+
+        return {self.name: self.value}
+
+
 class MPIMsgEnriched(MPIMsg):
 
     _C_field_ofss = 'ofss'
diff --git a/tests/test_mpi.py b/tests/test_mpi.py
index b6332a476c..c60fe6f2cd 100644
--- a/tests/test_mpi.py
+++ b/tests/test_mpi.py
@@ -778,7 +778,7 @@ def test_trivial_eq_1d_save(self, mode):
         else:
             assert np.all(f.data_ro_domain[-1, :-time_M] == 31.)
 
-    @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'diag'), (4, 'overlap'),
+    @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'basic2'), (4, 'diag'), (4, 'overlap'),
                                 (4, 'overlap2'), (4, 'diag2'), (4, 'full')])
     def test_trivial_eq_2d(self, mode):
         grid = Grid(shape=(8, 8,))
@@ -814,7 +814,7 @@ def test_trivial_eq_2d(self, mode):
         assert np.all(f.data_ro_domain[0, :-1, -1:] == side)
         assert np.all(f.data_ro_domain[0, -1:, :-1] == side)
 
-    @pytest.mark.parallel(mode=[(8, 'basic'), (8, 'diag'), (8, 'overlap'),
+    @pytest.mark.parallel(mode=[(8, 'basic2'), (8, 'diag'), (8, 'overlap'),
                                 (8, 'overlap2'), (8, 'diag2'), (8, 'full')])
     def test_trivial_eq_3d(self, mode):
         grid = Grid(shape=(8, 8, 8))
@@ -1539,6 +1539,7 @@ def test_diag2_quality(self, mode):
 
     @pytest.mark.parallel(mode=[
         (1, 'basic'),
+        (1, 'basic2'),
        (1, 'diag'),
         (1, 'overlap'),
         (1, 'overlap2'),
@@ -1565,6 +1566,11 @@ def test_min_code_size(self, mode):
             assert len(calls) == 1
             assert calls[0].name == 'haloupdate0'
             assert calls[0].ncomps == 2
+        elif configuration['mpi'] in ('basic2'):
+            assert len(op._func_table) == 4
+            assert len(calls) == 1  # haloupdate
+            assert calls[0].name == 'haloupdate0'
+            assert 'haloupdate1' not in op._func_table
         elif configuration['mpi'] in ('overlap'):
             assert len(op._func_table) == 8
             assert len(calls) == 4  # haloupdate, compute, halowait, remainder
@@ -2097,7 +2103,7 @@ def test_nontrivial_operator(self, mode):
         if not glb_pos_map[x] and not glb_pos_map[y]:
             assert np.all(u.data_ro_domain[1] == 3)
 
-    @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'overlap'), (4, 'full')])
+    @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'basic2'), (4, 'overlap'), (4, 'full')])
     def test_coupled_eqs_mixed_dims(self, mode):
         """
         Test an Operator that computes coupled equations over partly disjoint sets
@@ -2712,7 +2718,7 @@ def run_adjoint_F(self, nd):
         assert np.isclose((term1 - term2)/term1, 0., rtol=1.e-10)
 
     @pytest.mark.parametrize('nd', [1, 2, 3])
-    @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'diag'), (4, 'overlap'),
+    @pytest.mark.parallel(mode=[(4, 'basic2'), (4, 'diag'), (4, 'overlap'),
                                 (4, 'overlap2'), (4, 'full')])
     def test_adjoint_F(self, nd, mode):
         self.run_adjoint_F(nd)
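
Usage sketch (not part of the patch): a minimal driver modelled on test_trivial_eq_2d above, selecting the new halo-exchange scheme via configuration['mpi'] = 'basic2' (or the DEVITO_MPI=basic2 environment variable). The script name, the 4-rank launch line, and the number of timesteps are illustrative assumptions.

    # demo_basic2.py -- hypothetical example; run under MPI, e.g.
    #   DEVITO_MPI=basic2 mpirun -n 4 python demo_basic2.py
    from devito import Grid, TimeFunction, Eq, Operator, configuration

    # Pick Basic2HaloExchangeBuilder from mpi_registry (registered by this patch)
    configuration['mpi'] = 'basic2'

    grid = Grid(shape=(8, 8))
    x, y = grid.dimensions
    t = grid.stepping_dim

    f = TimeFunction(name='f', grid=grid)
    f.data_with_halo[:] = 1.

    # The cross-shaped stencil reads the four neighbours, so a haloupdate
    # routine is generated ahead of the compute loop (cf. test_min_code_size)
    eqn = Eq(f.forward, f[t, x-1, y] + f[t, x+1, y] + f[t, x, y-1] + f[t, x, y+1])
    op = Operator(eqn)
    op.apply(time=2)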