Skip to content

Commit

Permalink
mpi: patch missing dimensions for halospot
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 9, 2023
1 parent ef3b03c commit 3c305a2
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 8 deletions.
5 changes: 4 additions & 1 deletion devito/ir/iet/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def iet_build(stree):
nsections += 1

elif i.is_Halo:
body = HaloSpot(queues.pop(i), i.halo_scheme)
try:
body = HaloSpot(queues.pop(i), i.halo_scheme)
except KeyError:
body = HaloSpot(None, i.halo_scheme)

elif i.is_Sync:
body = SyncSpot(i.sync_ops, body=queues.pop(i, None))
Expand Down
23 changes: 21 additions & 2 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@ def stree_build(clusters, profiler=None, **kwargs):
tip = augment_whole_subtree(c, tip, mapper, it)

# Attach NodeHalo if necessary
for it, v in mapper.items():
for it, v in reversed(mapper.items()):
if needs_nodehalo(it.dim, c.halo_scheme):
v.bottom.parent = NodeHalo(c.halo_scheme, v.bottom.parent)
break
if needs_nodehalo_dim(it.dim, c.halo_scheme):
v.bottom.children = [NodeHalo(c.halo_scheme, v.bottom.parent)]
break

# Add in NodeExprs
exprs = []
Expand Down Expand Up @@ -182,7 +185,19 @@ def preprocess(clusters, options=None, **kwargs):
processed.append(c.rebuild(exprs=[], ispace=ispace, syncs=syncs))

halo_scheme = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=halo_scheme))
if halo_scheme:
itdims = set(c.ispace.dimensions)
hispace = c.ispace.project(itdims - halo_scheme.distributed_aindices)
else:
hispace = None

if hispace:
processed.append(c.rebuild(exprs=[],
ispace=hispace,
halo_scheme=halo_scheme))
processed.append(c.rebuild(halo_scheme=None))
else:
processed.append(c.rebuild(halo_scheme=halo_scheme))

# Sanity check!
try:
Expand Down Expand Up @@ -229,6 +244,10 @@ def needs_nodehalo(d, hs):
return d and hs and d._defines.intersection(hs.distributed_aindices)


def needs_nodehalo_dim(d, hs):
    """
    Return a truthy value iff the Dimension `d` (or any Dimension it
    defines) appears among the local indices of the HaloScheme `hs`.

    Mirrors `needs_nodehalo`, but checks `hs.loc_indices` rather than
    `hs.distributed_aindices`.
    """
    # Short-circuit exactly like `d and hs and ...`: a falsy `d` (or a
    # falsy `hs` with truthy `d`) is propagated unchanged to the caller.
    if not d or not hs:
        return d and hs
    return d._defines.intersection(hs.loc_indices)


def reuse_section(candidate, section):
try:
if not section or candidate.siblings[-1] is not section:
Expand Down
2 changes: 1 addition & 1 deletion devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _augment_implicit_dims(self, implicit_dims, extras=None):
idims = self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
idims = extra + as_tuple(implicit_dims) + self.sfunction.dimensions
return tuple(filter_ordered(idims))
return tuple(idims)

def _coeff_temps(self, implicit_dims):
return []
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def optimize_halospots(iet, **kwargs):
around in order to improve the halo exchange performance.
"""
iet = _drop_halospots(iet)
iet = _hoist_halospots(iet)
# iet = _hoist_halospots(iet)
iet = _merge_halospots(iet)
iet = _drop_if_unwritten(iet, **kwargs)
iet = _mark_overlappable(iet)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,8 @@ class SparseFirst(SparseFunction):
ds = DefaultDimension("ps", default_value=3)
grid = Grid((11, 11))
dims = grid.dimensions
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3))
s.coordinates.data[:] = [[.5, .5], [.2, .2]]
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3),
coordinates=[[.5, .5], [.2, .2]])

# Check dimensions and shape are correctly initialized
assert s.indices[s._sparse_position] == dr
Expand Down
36 changes: 35 additions & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,41 @@ def test_precomputed_sparse(self, r):
Operator(sf1.interpolate(u))()
assert np.all(sf1.data == 4)

@pytest.mark.parallel(mode=4)
def test_sparse_first(self):
    """
    Tests a custom sparse function with the sparse dimension as the first index.

    Runs under MPI with 4 ranks (parallel mode=4).
    """

    class SparseFirst(SparseFunction):
        """Custom sparse class with the sparse dimension as the first one."""
        # Position of the sparse dimension within `dimensions` (default is last)
        _sparse_position = 0

    dr = Dimension("cd")
    ds = DefaultDimension("ps", default_value=3)
    grid = Grid((11, 11))
    dims = grid.dimensions
    # Coordinates are passed at construction; the two points fall on
    # different ranks once the grid is decomposed across 4 processes
    s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3),
                    coordinates=[[.5, .5], [.2, .2]])

    # Check dimensions and shape are correctly initialized
    assert s.indices[s._sparse_position] == dr
    assert s.shape == (2, 3)
    assert s.coordinates.indices[0] == dr

    # Operator
    u = TimeFunction(name="u", grid=grid, time_order=1)
    fs = Function(name="fs", grid=grid, dimensions=(*dims, ds), shape=(11, 11, 3))

    eqs = [Eq(u.forward, u+1), Eq(fs, u)]
    # No time dependence so need the implicit dim
    rec = s.interpolate(expr=s+fs, implicit_dims=grid.stepping_dim)
    op = Operator(eqs + rec)

    op(time_M=10)
    # u is incremented by 1 per timestep, so the interpolated value after
    # 10 steps is the triangular sum 1+2+...+10
    expected = 10*11/2  # n (n+1)/2
    assert np.allclose(s.data, expected)

@pytest.mark.parallel(mode=4)
def test_no_grid_dim_slow(self):
shape = (12, 13, 14)
Expand All @@ -624,7 +659,6 @@ class CoordSlowSparseFunction(SparseFunction):
rec_eq = s.interpolate(expr=u)

op = Operator([Eq(u, 1)] + rec_eq)
print(op)
op.apply()
assert np.all(s.data == 1)

Expand Down

0 comments on commit 3c305a2

Please sign in to comment.