Skip to content

Commit

Permalink
avoiding buffer allocations in each time step in MPI boundary conditi…
Browse files Browse the repository at this point in the history
…ons, closes #53 (#61)

* initial changes

* adapt to a new API

* Update pyproject.toml

* remove numpy

* precommit

* force older Numpy to avoid NaN-cast warnings in PyMPDATA fields ctors

* disable kwargs

* comment non-supported option

* downgrade numba

* try passing dtype to view()

* code formatting fix

* add missing dtype params

* fix shape setting for Numba

* add complex buffer support and uncomment back non-oscillatory tests

* use chank assignement instead of reshape to ensure nojit tests fail if mem needs to be copied

* silence pylint on numba.config access

* make the code compilable

---------

Co-authored-by: derlk <[email protected]>
Co-authored-by: Sylwester Arabas <[email protected]>
  • Loading branch information
3 people authored May 16, 2023
1 parent 3ae6bbf commit 3a62966
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 35 deletions.
83 changes: 51 additions & 32 deletions PyMPDATA_MPI/periodic.py → PyMPDATA_MPI/mpi_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numba
import numba_mpi as mpi
import numpy as np
from mpi4py import MPI
from PyMPDATA.boundary_conditions import Periodic
from PyMPDATA.impl.enumerations import INVALID_INDEX, SIGN_LEFT, SIGN_RIGHT
Expand All @@ -30,7 +29,9 @@ def make_scalar(self, indexers, halo, dtype, jit_flags, dimension_index):
return Periodic.make_scalar(
indexers, halo, dtype, jit_flags, dimension_index
)
return _make_scalar_periodic(indexers, jit_flags, dimension_index, self.__size)
return _make_scalar_periodic(
indexers, jit_flags, dimension_index, self.__size, dtype
)

def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index):
"""returns (lru-cached) Numba-compiled vector halo-filling callable"""
Expand All @@ -39,34 +40,27 @@ def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index):
indexers, halo, dtype, jit_flags, dimension_index
)
return _make_vector_periodic(
indexers, halo, jit_flags, dimension_index, self.__size
indexers, halo, jit_flags, dimension_index, self.__size, dtype
)


def _make_send_recv(set_value, jit_flags, fill_buf):
def _make_send_recv(set_value, jit_flags, fill_buf, size, dtype):
@numba.njit(**jit_flags)
def _send_recv(size, psi, i_rng, j_rng, k_rng, sign, dim, output):
buf = np.empty(
(
len(i_rng),
len(k_rng),
),
dtype=output.dtype,
)
def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
chunk_size = len(i_rng) * len(k_rng)
return buffer.view(dtype)[
chunk_index * chunk_size : (chunk_index + 1) * chunk_size
].reshape((len(i_rng), len(k_rng)))

@numba.njit(**jit_flags)
def get_peers():
rank = mpi.rank()
peers = (-1, (rank - 1) % size, (rank + 1) % size) # LEFT # RIGHT

if SIGN_LEFT == sign:
fill_buf(buf, psi, i_rng, k_rng, sign, dim)
mpi.send(buf, dest=peers[sign])
mpi.recv(buf, source=peers[sign])
elif SIGN_RIGHT == sign:
mpi.recv(buf, source=peers[sign])
tmp = np.empty_like(buf)
fill_buf(tmp, psi, i_rng, k_rng, sign, dim)
mpi.send(tmp, dest=peers[sign])
left_peer = (rank - 1) % size
right_peer = (rank + 1) % size
return (-1, left_peer, right_peer)

@numba.njit(**jit_flags)
def fill_output(output, buffer, i_rng, j_rng, k_rng):
for i in i_rng:
for j in j_rng:
for k in k_rng:
Expand All @@ -75,14 +69,39 @@ def _send_recv(size, psi, i_rng, j_rng, k_rng, sign, dim, output):
i,
j,
k,
buf[i - i_rng.start, k - k_rng.start],
buffer[i - i_rng.start, k - k_rng.start],
)

@numba.njit(**jit_flags)
def _send(buf, peer, fill_buf_args):
fill_buf(buf, *fill_buf_args)
mpi.send(buf, dest=peer)

@numba.njit(**jit_flags)
def _recv(buf, peer):
mpi.recv(buf, source=peer)

@numba.njit(**jit_flags)
def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
peers = get_peers()
fill_buf_args = (psi, i_rng, k_rng, sign, dim)

if SIGN_LEFT == sign:
_send(buf=buf, peer=peers[sign], fill_buf_args=fill_buf_args)
_recv(buf=buf, peer=peers[sign])
elif SIGN_RIGHT == sign:
_recv(buf=buf, peer=peers[sign])
tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
_send(buf=tmp, peer=peers[sign], fill_buf_args=fill_buf_args)

fill_output(output, buf, i_rng, j_rng, k_rng)

return _send_recv


@lru_cache()
def _make_scalar_periodic(indexers, jit_flags, dimension_index, size):
def _make_scalar_periodic(indexers, jit_flags, dimension_index, size, dtype):
@numba.njit(**jit_flags)
def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
for i in i_rng:
Expand All @@ -91,17 +110,17 @@ def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
(i, INVALID_INDEX, k), psi, sign
)

send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf)
send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf, size, dtype)

@numba.njit(**jit_flags)
def fill_halos(i_rng, j_rng, k_rng, psi, _, sign):
send_recv(size, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)
def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

return fill_halos


@lru_cache()
def _make_vector_periodic(indexers, halo, jit_flags, dimension_index, size):
def _make_vector_periodic(indexers, halo, jit_flags, dimension_index, size, dtype):
@numba.njit(**jit_flags)
def fill_buf(buf, components, i_rng, k_rng, sign, dim):
parallel = dim % len(components) == dimension_index
Expand All @@ -119,12 +138,12 @@ def fill_buf(buf, components, i_rng, k_rng, sign, dim):

buf[i - i_rng.start, k - k_rng.start] = value

send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf)
send_recv = _make_send_recv(indexers.set, jit_flags, fill_buf, size, dtype)

@numba.njit(**jit_flags)
def fill_halos_loop_vector(i_rng, j_rng, k_rng, components, dim, _, sign):
def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
return
send_recv(size, components, i_rng, j_rng, k_rng, sign, dim, components[dim])
send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

return fill_halos_loop_vector
6 changes: 5 additions & 1 deletion PyMPDATA_MPI/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PyMPDATA.boundary_conditions import Periodic

from .domain_decomposition import mpi_indices
from .periodic import MPIPeriodic
from .mpi_periodic import MPIPeriodic


class Simulation:
Expand Down Expand Up @@ -45,6 +45,10 @@ def __init__(
n_dims=2,
n_threads=n_threads,
left_first=rank % 2 == 0,
# TODO https://github.com/open-atmos/PyMPDATA/issues/386
buffer_size=((ny + 2 * halo) * halo)
* 2 # for temporary send/recv buffer on one side
* 2, # for complex dtype
)
self.solver = Solver(stepper=stepper, advectee=self.advectee, advector=advector)

Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ classifiers = [
"Topic :: Scientific/Engineering :: Physics"
]
dependencies = [
# TODO: these should be handled within PyMPDATA?
"numba<0.57.0",
"numpy<1.24.0",
"numba_mpi>=0.30",
"PyMPDATA==1.0.10",
"PyMPDATA==1.0.11",
"mpi4py",
"h5py",
"pytest-mpi"
"pytest-mpi" # TODO: move it to optional dependencies (extras_require?)
]
dynamic = ["version"]

0 comments on commit 3a62966

Please sign in to comment.