# Source code for veros.distributed

import functools

from veros import runtime_settings as rs, runtime_state as rst
from veros.routines import CURRENT_CONTEXT

# Grid dimension names that are decomposed across processes: names in the first
# tuple are split along the x axis of the process grid, names in the second
# tuple along the y axis. Any other dimension is kept whole on every process.
SCATTERED_DIMENSIONS = (("xt", "xu"), ("yt", "yu"))


def dist_context_only(function=None, *, noop_return_arg=None):
    """Decorator that runs the wrapped function only in a distributed context.

    The wrapped function is skipped when there is a single process
    (``rst.proc_num == 1``) or when the current routine context is not marked
    as distributed-safe. In that case the wrapper returns ``None`` if
    ``noop_return_arg`` is ``None``. Otherwise it returns the positional
    argument at index ``noop_return_arg`` unchanged.

    Usable both bare (``@dist_context_only``) and with options
    (``@dist_context_only(noop_return_arg=0)``).
    """

    def _make_wrapper(func):
        @functools.wraps(func)
        def dist_context_only_wrapper(*args, **kwargs):
            if rst.proc_num != 1 and CURRENT_CONTEXT.is_dist_safe:
                return func(*args, **kwargs)

            # sequential execution: do nothing, optionally echoing an input back
            if noop_return_arg is None:
                return None
            return args[noop_return_arg]

        return dist_context_only_wrapper

    if function is None:
        return _make_wrapper

    return _make_wrapper(function)


def send(buf, dest, comm, tag=None):
    """Send ``buf`` to rank ``dest`` of ``comm``.

    With the JAX backend this calls ``mpi4jax.send`` and threads the current
    context's ``mpi4jax_token`` through the call. Otherwise it uses mpi4py's
    buffer-based ``comm.Send`` on a C-contiguous version of ``buf``. ``tag``
    is forwarded only when it is given.
    """
    tag_kwargs = {} if tag is None else {"tag": tag}

    if rs.backend != "jax":
        comm.Send(ascontiguousarray(buf), dest=dest, **tag_kwargs)
        return

    from mpi4jax import send as mpi4jax_send

    CURRENT_CONTEXT.mpi4jax_token = mpi4jax_send(
        buf, dest=dest, comm=comm, token=CURRENT_CONTEXT.mpi4jax_token, **tag_kwargs
    )


def recv(buf, source, comm, tag=None):
    """Receive a message from rank ``source`` of ``comm`` and return it.

    ``buf`` serves as the template for the incoming data. The NumPy path
    receives into a copy of ``buf``, so the caller's array is left untouched.
    The JAX path passes ``buf`` to ``mpi4jax.recv``, updates the context's
    ``mpi4jax_token`` and returns the received array. ``tag`` is forwarded
    only when it is given.
    """
    tag_kwargs = {} if tag is None else {"tag": tag}

    if rs.backend != "jax":
        received = buf.copy()
        comm.Recv(received, source=source, **tag_kwargs)
        return received

    from mpi4jax import recv as mpi4jax_recv

    received, CURRENT_CONTEXT.mpi4jax_token = mpi4jax_recv(
        buf, source=source, comm=comm, token=CURRENT_CONTEXT.mpi4jax_token, **tag_kwargs
    )
    return received


def sendrecv(sendbuf, recvbuf, source, dest, comm, sendtag=None, recvtag=None):
    """Send ``sendbuf`` to ``dest`` while receiving from ``source`` in one call.

    ``recvbuf`` is the template for the incoming data. The NumPy path receives
    into a copy of it, so the caller's array is not modified. The received
    data is returned. ``sendtag`` / ``recvtag`` are forwarded only when given.
    With the JAX backend the context's ``mpi4jax_token`` is threaded through.
    """
    tag_kwargs = {
        name: value
        for name, value in (("sendtag", sendtag), ("recvtag", recvtag))
        if value is not None
    }

    if rs.backend != "jax":
        received = recvbuf.copy()
        comm.Sendrecv(
            sendbuf=ascontiguousarray(sendbuf),
            recvbuf=received,
            source=source,
            dest=dest,
            **tag_kwargs,
        )
        return received

    from mpi4jax import sendrecv as mpi4jax_sendrecv

    received, CURRENT_CONTEXT.mpi4jax_token = mpi4jax_sendrecv(
        sendbuf,
        recvbuf,
        source=source,
        dest=dest,
        comm=comm,
        token=CURRENT_CONTEXT.mpi4jax_token,
        **tag_kwargs,
    )
    return received


def bcast(buf, comm, root=0):
    """Broadcast ``buf`` from rank ``root`` to every rank of ``comm`` and return it.

    The JAX path uses ``mpi4jax.bcast`` and threads the context's token. The
    other path uses mpi4py's lowercase ``comm.bcast``, which operates on
    generic Python objects rather than raw buffers.
    """
    if rs.backend != "jax":
        return comm.bcast(buf, root=root)

    from mpi4jax import bcast as mpi4jax_bcast

    result, CURRENT_CONTEXT.mpi4jax_token = mpi4jax_bcast(
        buf, root=root, comm=comm, token=CURRENT_CONTEXT.mpi4jax_token
    )
    return result


def allreduce(buf, op, comm):
    """Combine ``buf`` over all ranks of ``comm`` with reduction ``op``.

    Args:
        buf: Array to reduce. It is not modified.
        op: MPI reduction operation (e.g. ``MPI.SUM``).
        comm: Communicator spanning the participating processes.

    Returns:
        A new array holding the element-wise reduction result.

    With the JAX backend this calls ``mpi4jax.allreduce`` and threads the
    context's ``mpi4jax_token``. Otherwise mpi4py's buffer-based
    ``comm.Allreduce`` is used.
    """
    if rs.backend == "jax":
        from mpi4jax import allreduce

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = allreduce(buf, op=op, comm=comm, token=token)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf

    from veros.core.operators import numpy as npx

    # The send buffer is made C-contiguous below, so the receive buffer must be
    # C-ordered too. A plain empty_like() would copy the memory layout of
    # ``buf`` (e.g. Fortran order for a transposed array). MPI treats both
    # buffers as flat memory and would then write the reduced values into the
    # wrong elements.
    recvbuf = npx.empty_like(buf, order="C")
    comm.Allreduce(ascontiguousarray(buf), recvbuf, op=op)
    return recvbuf


def ascontiguousarray(arr):
    """Return ``arr`` as a C-contiguous NumPy array (NumPy backend only).

    Callers in this module only use it on their non-JAX code paths, and the
    assertion enforces that the NumPy backend is active.
    """
    assert rs.backend == "numpy"
    from numpy import ascontiguousarray as _np_ascontiguousarray

    return _np_ascontiguousarray(arr)


def validate_decomposition(dimensions):
    """Check that the configured process grid fits the communicator and domain.

    Args:
        dimensions: Mapping providing the global extents ``"xt"`` and ``"yt"``.

    Raises:
        RuntimeError: if more than one process is requested but no MPI
            communicator is available, or if ``num_proc[0] * num_proc[1]``
            does not equal the communicator size.
        ValueError: if the process counts do not evenly divide the domain.
    """
    nx = dimensions["xt"]
    ny = dimensions["yt"]

    comm = rs.mpi_comm
    if comm is None:
        wants_multiple_procs = rs.num_proc[0] > 1 or rs.num_proc[1] > 1
        if wants_multiple_procs:
            raise RuntimeError("mpi4py is required for distributed execution")
        return

    procs_x, procs_y = rs.num_proc[0], rs.num_proc[1]
    comm_size = comm.Get_size()
    proc_num = procs_x * procs_y

    if proc_num != comm_size:
        raise RuntimeError(f"number of processes ({proc_num}) does not match size of communicator ({comm_size})")

    if nx % procs_x:
        raise ValueError("processes do not divide domain evenly in x-direction")

    if ny % procs_y:
        raise ValueError("processes do not divide domain evenly in y-direction")


def get_chunk_size(nx, ny):
    """Return the per-process chunk extent ``(nx_local, ny_local)``, excluding overlap."""
    procs_x, procs_y = rs.num_proc[0], rs.num_proc[1]
    return nx // procs_x, ny // procs_y


def proc_rank_to_index(rank):
    """Map a flat MPI ``rank`` to its ``(ix, iy)`` position in the process grid.

    Ranks are numbered x-fastest: consecutive ranks step along x first.
    """
    iy, ix = divmod(rank, rs.num_proc[0])
    return (ix, iy)


def proc_index_to_rank(ix, iy):
    """Map a process-grid position ``(ix, iy)`` to its flat rank (x-fastest numbering)."""
    procs_x = rs.num_proc[0]
    return iy * procs_x + ix


def get_chunk_slices(nx, ny, dim_grid, proc_idx=None, include_overlap=False):
    """Locate one process's chunk in the global and the local array.

    Args:
        nx, ny: Global domain extents (without overlap).
        dim_grid: Dimension names of the array. Names in
            ``SCATTERED_DIMENSIONS`` are sliced; all others get ``slice(None)``.
        proc_idx: ``(ix, iy)`` grid position. Defaults to the calling process.
        include_overlap: If true, also cover the 2-cell halo. It is included
            only on sides that lie on the outer domain edge; interior halo
            cells (owned by a neighbor) are excluded.

    Returns:
        ``(global_slice, local_slice)`` tuples with one entry per dimension, or
        ``(Ellipsis, Ellipsis)`` when ``dim_grid`` is empty or ``None``.
    """
    if not dim_grid:
        return Ellipsis, Ellipsis

    px, py = proc_rank_to_index(rst.proc_rank) if proc_idx is None else proc_idx
    nxl, nyl = get_chunk_size(nx, ny)

    if not include_overlap:
        x_lo, x_hi = 0, nxl
        y_lo, y_hi = 0, nyl
    else:
        # keep the 2-cell halo only where this chunk touches the domain edge
        x_lo = 0 if px == 0 else 2
        x_hi = nxl + 4 if (px + 1) == rs.num_proc[0] else nxl + 2
        y_lo = 0 if py == 0 else 2
        y_hi = nyl + 4 if (py + 1) == rs.num_proc[1] else nyl + 2

    per_dim = []
    for dim in dim_grid:
        if dim in SCATTERED_DIMENSIONS[0]:
            per_dim.append((slice(x_lo + px * nxl, x_hi + px * nxl), slice(x_lo, x_hi)))
        elif dim in SCATTERED_DIMENSIONS[1]:
            per_dim.append((slice(y_lo + py * nyl, y_hi + py * nyl), slice(y_lo, y_hi)))
        else:
            per_dim.append((slice(None), slice(None)))

    global_slice = tuple(glob for glob, _ in per_dim)
    local_slice = tuple(loc for _, loc in per_dim)
    return global_slice, local_slice


def get_process_neighbors(cyclic=False):
    """Return the ranks of the 8 grid neighbors of the calling process.

    Keys are ``west``, ``south``, ``east``, ``north`` and the corners
    ``southwest``, ``southeast``, ``northeast``, ``northwest``. A value is
    ``None`` when there is no neighbor in that direction. With
    ``cyclic=True`` the x direction wraps around; the y direction never does.
    """
    ix, iy = proc_rank_to_index(rst.proc_rank)
    last_x = rs.num_proc[0] - 1
    last_y = rs.num_proc[1] - 1

    if ix != 0:
        west = ix - 1
    elif cyclic:
        west = last_x
    else:
        west = None

    if ix != last_x:
        east = ix + 1
    elif cyclic:
        east = 0
    else:
        east = None

    south = None if iy == 0 else iy - 1
    north = None if iy == last_y else iy + 1

    grid_positions = {
        # direct neighbors
        "west": (west, iy),
        "south": (ix, south),
        "east": (east, iy),
        "north": (ix, north),
        # corners
        "southwest": (west, south),
        "southeast": (east, south),
        "northeast": (east, north),
        "northwest": (west, north),
    }

    return {
        direction: None if None in position else proc_index_to_rank(*position)
        for direction, position in grid_positions.items()
    }


@dist_context_only(noop_return_arg=0)
def exchange_overlap(arr, var_grid, cyclic):
    """Fill the 2-cell overlap (halo) regions of ``arr`` with neighbor data.

    Along each decomposed axis, cells ``2:4`` / ``-4:-2`` are sent to the
    neighbor on the low / high side. The data coming back is written into
    ``-2:`` / ``0:2``. If the first two dimensions of ``var_grid`` are
    (x, y), edges and corners are exchanged with all 8 neighbors. If only the
    leading dimension is decomposed, just the two neighbors along it take part.

    Args:
        arr: Process-local array, including its overlap cells.
        var_grid: Dimension names of ``arr``. Only the first two are inspected.
        cyclic: If true, the x direction wraps around (passed on to
            ``get_process_neighbors``). The y direction never wraps.

    Returns:
        ``arr`` with its overlap regions updated. It is returned unchanged if
        no leading dimension is decomposed, and (via the decorator) in
        sequential / non-distributed-safe contexts.

    Raises:
        NotImplementedError: if the leading dimension is not decomposed but
            the second one is a y dimension.
    """
    from veros.core.operators import numpy as npx, update, at

    # start west, go clockwise
    send_order = (
        "west",
        "northwest",
        "north",
        "northeast",
        "east",
        "southeast",
        "south",
        "southwest",
    )

    # start east, go clockwise
    # recv_order[i] is always opposite to send_order[i]: while a process sends
    # towards its west neighbor, that neighbor receives from its own east, so
    # every transfer in a step has a matching partner.
    recv_order = (
        "east",
        "southeast",
        "south",
        "southwest",
        "west",
        "northwest",
        "north",
        "northeast",
    )

    # only the two leading dimensions determine how the array is decomposed
    if len(var_grid) < 2:
        d1, d2 = var_grid[0], None
    else:
        d1, d2 = var_grid[:2]

    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent, nothing to do
        return arr

    proc_neighbors = get_process_neighbors(cyclic)

    if d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        # (x, y) layout: slices index axis 0 = x and axis 1 = y
        overlap_slices_from = dict(
            west=(slice(2, 4), slice(0, None), Ellipsis),
            south=(slice(0, None), slice(2, 4), Ellipsis),
            east=(slice(-4, -2), slice(0, None), Ellipsis),
            north=(slice(0, None), slice(-4, -2), Ellipsis),
            southwest=(slice(2, 4), slice(2, 4), Ellipsis),
            southeast=(slice(-4, -2), slice(2, 4), Ellipsis),
            northeast=(slice(-4, -2), slice(-4, -2), Ellipsis),
            northwest=(slice(2, 4), slice(-4, -2), Ellipsis),
        )

        overlap_slices_to = dict(
            west=(slice(0, 2), slice(0, None), Ellipsis),
            south=(slice(0, None), slice(0, 2), Ellipsis),
            east=(slice(-2, None), slice(0, None), Ellipsis),
            north=(slice(0, None), slice(-2, None), Ellipsis),
            southwest=(slice(0, 2), slice(0, 2), Ellipsis),
            southeast=(slice(-2, None), slice(0, 2), Ellipsis),
            northeast=(slice(-2, None), slice(-2, None), Ellipsis),
            northwest=(slice(0, 2), slice(-2, None), Ellipsis),
        )

    else:
        # only the leading axis is decomposed: exchange with the two neighbors
        # along that axis
        if d1 in SCATTERED_DIMENSIONS[0]:
            send_order = ("west", "east")
            recv_order = ("east", "west")
        elif d1 in SCATTERED_DIMENSIONS[1]:
            send_order = ("north", "south")
            recv_order = ("south", "north")
        else:
            # reached when d1 is not decomposed but d2 is a y dimension
            raise NotImplementedError()

        # the decomposed axis is axis 0 in both cases, so the low side
        # (west/south) and the high side (east/north) share the same slices
        overlap_slices_from = dict(
            west=(slice(2, 4), Ellipsis),
            south=(slice(2, 4), Ellipsis),
            east=(slice(-4, -2), Ellipsis),
            north=(slice(-4, -2), Ellipsis),
        )

        overlap_slices_to = dict(
            west=(slice(0, 2), Ellipsis),
            south=(slice(0, 2), Ellipsis),
            east=(slice(-2, None), Ellipsis),
            north=(slice(-2, None), Ellipsis),
        )

    for send_dir, recv_dir in zip(send_order, recv_order):
        send_proc = proc_neighbors[send_dir]
        recv_proc = proc_neighbors[recv_dir]

        # no neighbor on either side in this step (domain edges), nothing to do
        if send_proc is None and recv_proc is None:
            continue

        # template with the shape/dtype of the region that will be received
        recv_idx = overlap_slices_to[recv_dir]
        recv_arr = npx.empty_like(arr[recv_idx])

        send_idx = overlap_slices_from[send_dir]
        send_arr = arr[send_idx]

        if send_proc is None:
            # nobody to send to in this direction: receive only
            recv_arr = recv(recv_arr, recv_proc, rs.mpi_comm)
            arr = update(arr, at[recv_idx], recv_arr)
        elif recv_proc is None:
            # nobody to receive from in this direction: send only
            send(send_arr, send_proc, rs.mpi_comm)
        else:
            recv_arr = sendrecv(send_arr, recv_arr, source=recv_proc, dest=send_proc, comm=rs.mpi_comm)
            arr = update(arr, at[recv_idx], recv_arr)

    return arr


def _memoize(function):
    cached = {}

    @functools.wraps(function)
    def memoized(*args):
        from mpi4py import MPI

        # MPI Comms are not hashable, so we use the underlying handle instead
        cache_args = tuple(MPI._handleof(arg) if isinstance(arg, MPI.Comm) else arg for arg in args)

        if cache_args not in cached:
            cached[cache_args] = function(*args)

        return cached[cache_args]

    return memoized


@_memoize
def _mpi_comm_along_axis(comm, procs, rank):
    """Split ``comm`` into sub-communicators, one per distinct ``procs`` value.

    ``procs`` is the split color and ``rank`` the ordering key within each new
    communicator. The result is memoized, so each process performs the
    (collective) split only once per argument combination.
    """
    color, key = procs, rank
    return comm.Split(color, key)


@dist_context_only(noop_return_arg=0)
def _reduce(arr, op, axis=None):
    """Reduce ``arr`` with MPI operation ``op`` across processes.

    With ``axis=None`` the reduction spans every process in ``rs.mpi_comm``.
    With ``axis`` 0 or 1 it spans only the processes that share this process's
    grid index along the other axis, i.e. one row/column of the process grid.
    Scalars (per ``npx.isscalar``) are wrapped in a 1-element array for the
    collective and unwrapped again afterwards.
    """
    from veros.core.operators import numpy as npx

    if axis is not None:
        assert axis in (0, 1)
        grid_index = proc_rank_to_index(rst.proc_rank)
        # processes with the same index along the *other* axis get the same color
        color = grid_index[1 - axis]
        comm = _mpi_comm_along_axis(rs.mpi_comm, color, rst.proc_rank)
    else:
        comm = rs.mpi_comm

    is_scalar = npx.isscalar(arr)
    buffer = npx.array([arr]) if is_scalar else arr

    reduced = allreduce(buffer, op=op, comm=comm)
    return reduced[0] if is_scalar else reduced


@dist_context_only(noop_return_arg=0)
def global_and(arr, axis=None):
    """Logical AND of ``arr`` across all processes, or along one process-grid ``axis``."""
    from mpi4py import MPI

    return _reduce(arr, op=MPI.LAND, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_or(arr, axis=None):
    """Logical OR of ``arr`` across all processes, or along one process-grid ``axis``."""
    from mpi4py import MPI

    return _reduce(arr, op=MPI.LOR, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_max(arr, axis=None):
    """Maximum of ``arr`` across all processes, or along one process-grid ``axis``.

    Returns ``arr`` unchanged in sequential / non-distributed-safe contexts.
    """
    from mpi4py import MPI

    return _reduce(arr, MPI.MAX, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_min(arr, axis=None):
    """Minimum of ``arr`` across all processes, or along one process-grid ``axis``.

    Returns ``arr`` unchanged in sequential / non-distributed-safe contexts.
    """
    from mpi4py import MPI

    return _reduce(arr, MPI.MIN, axis=axis)


@dist_context_only(noop_return_arg=0)
def global_sum(arr, axis=None):
    """Sum of ``arr`` across all processes, or along one process-grid ``axis``.

    Returns ``arr`` unchanged in sequential / non-distributed-safe contexts.
    """
    from mpi4py import MPI

    return _reduce(arr, MPI.SUM, axis=axis)


@dist_context_only(noop_return_arg=2)
def _gather_1d(nx, ny, arr, dim):
    """Collect an array decomposed along axis 0 (x if ``dim == 0``, y if ``dim == 1``) on rank 0.

    Only processes whose grid index along the other axis is 0 take part. All
    others return ``arr`` immediately. Non-root participants send their chunk
    (including edge halos) to rank 0 with tag 20 and return ``arr`` unchanged.

    Returns:
        On rank 0, the assembled array of length ``nx + 4`` (``dim == 0``) or
        ``ny + 4`` (``dim == 1``) along axis 0. Elsewhere, ``arr``.
    """
    from veros.core.operators import numpy as npx, update, at

    assert dim in (0, 1)

    otherdim = 1 - dim
    pi = proc_rank_to_index(rst.proc_rank)
    if pi[otherdim] != 0:
        return arr

    dim_grid = ["xt" if dim == 0 else "yt"] + [None] * (arr.ndim - 1)
    gidx, idx = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    sendbuf = arr[idx]

    if rst.proc_rank == 0:
        buffer_list = []
        for proc in range(1, rst.proc_num):
            other_pi = proc_rank_to_index(proc)
            if other_pi[otherdim] != 0:
                continue

            idx_g, idx_l = get_chunk_slices(nx, ny, dim_grid, include_overlap=True, proc_idx=other_pi)
            # all chunks have the same size, so rank 0's array serves as shape template
            recvbuf = npx.empty_like(arr[idx_l])
            recvbuf = recv(recvbuf, source=proc, tag=20, comm=rs.mpi_comm)
            buffer_list.append((idx_g, recvbuf))

        out_shape = ((nx + 4, ny + 4)[dim],) + arr.shape[1:]
        out = npx.empty(out_shape, dtype=arr.dtype)
        out = update(out, at[gidx], sendbuf)

        for global_idx, val in buffer_list:
            out = update(out, at[global_idx], val)

        return out

    send(sendbuf, dest=0, tag=20, comm=rs.mpi_comm)
    return arr


@dist_context_only(noop_return_arg=2)
def _gather_xy(nx, ny, arr):
    """Collect an array decomposed along x (axis 0) and y (axis 1) on rank 0.

    Every non-root process sends its chunk (including edge halos) to rank 0
    with tag 30 and returns ``arr`` unchanged.

    Returns:
        On rank 0, an array of shape ``(nx + 4, ny + 4) + arr.shape[2:]``.
        Elsewhere, ``arr``.
    """
    from veros.core.operators import numpy as npx, update, at

    nxi, nyi = get_chunk_size(nx, ny)
    assert arr.shape[:2] == (nxi + 4, nyi + 4), arr.shape

    dim_grid = ["xt", "yt"] + [None] * (arr.ndim - 2)
    gidx, idx = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    sendbuf = arr[idx]

    if rst.proc_rank == 0:
        buffer_list = []
        for proc in range(1, rst.proc_num):
            idx_g, idx_l = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            recvbuf = npx.empty_like(arr[idx_l])
            recvbuf = recv(recvbuf, source=proc, tag=30, comm=rs.mpi_comm)
            buffer_list.append((idx_g, recvbuf))

        out_shape = (nx + 4, ny + 4) + arr.shape[2:]
        out = npx.empty(out_shape, dtype=arr.dtype)
        out = update(out, at[gidx], sendbuf)

        for global_idx, val in buffer_list:
            out = update(out, at[global_idx], val)

        return out

    send(sendbuf, dest=0, tag=30, comm=rs.mpi_comm)
    return arr


@dist_context_only(noop_return_arg=0)
def gather(arr, dimensions, var_grid):
    """Assemble the global array on rank 0 from the process-local chunks.

    Args:
        arr: Process-local array (including its 2-cell overlap).
        dimensions: Mapping providing the global extents ``"xt"`` and ``"yt"``.
        var_grid: Dimension names of ``arr``. Only the first two are
            inspected. ``None`` or an empty grid means "not decomposed".

    Returns:
        On rank 0, the gathered global array for decomposed variables. On
        other ranks, and for non-decomposed variables, ``arr`` unchanged.

    Raises:
        NotImplementedError: if the leading dimension is not decomposed but
            the second one is a y dimension.
    """
    nx, ny = dimensions["xt"], dimensions["yt"]

    if var_grid is None:
        return arr

    # only the two leading dimensions matter; an empty grid is not decomposed
    d1 = var_grid[0] if len(var_grid) > 0 else None
    d2 = var_grid[1] if len(var_grid) > 1 else None

    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent, nothing to do
        return arr

    if d1 in SCATTERED_DIMENSIONS[0] and d2 not in SCATTERED_DIMENSIONS[1]:
        # only x dependent
        return _gather_1d(nx, ny, arr, 0)

    elif d1 in SCATTERED_DIMENSIONS[1]:
        # only y dependent
        return _gather_1d(nx, ny, arr, 1)

    elif d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        # x and y dependent
        return _gather_xy(nx, ny, arr)

    else:
        # e.g. ("zt", "yt"): second axis decomposed, leading axis not
        raise NotImplementedError(f"unsupported variable grid for gather: {var_grid}")


@dist_context_only(noop_return_arg=0)
def _scatter_constant(arr):
    """Broadcast ``arr`` from rank 0 to all processes of ``rs.mpi_comm``."""
    return bcast(arr, rs.mpi_comm, root=0)


@dist_context_only(noop_return_arg=2)
def _scatter_1d(nx, ny, arr, dim):
    """Distribute an array decomposed along axis 0 (x if ``dim == 0``, y if ``dim == 1``).

    Rank 0 holds the global array. It sends every other rank its chunk
    (including edge halos) with tag 40, then replaces ``arr`` by a zeroed
    local-sized array. On other ranks ``arr`` is the local-sized target. All
    ranks then fill their overlap via ``exchange_overlap`` (non-cyclic).

    Returns:
        The process-local chunk of length ``n_local + 4`` along axis 0.
    """
    from veros.core.operators import numpy as npx, update, at

    assert dim in (0, 1)

    out_nx = get_chunk_size(nx, ny)[dim]
    dim_grid = ["xt" if dim == 0 else "yt"] + [None] * (arr.ndim - 1)
    _, local_slice = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)

    if rst.proc_rank == 0:
        recvbuf = arr[local_slice]

        for proc in range(1, rst.proc_num):
            global_slice, _ = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            sendbuf = arr[global_slice]
            send(sendbuf, dest=proc, tag=40, comm=rs.mpi_comm)

        # arr changes shape in main process
        arr = npx.zeros((out_nx + 4,) + arr.shape[1:], dtype=arr.dtype)
    else:
        recvbuf = recv(arr[local_slice], source=0, tag=40, comm=rs.mpi_comm)

    arr = update(arr, at[local_slice], recvbuf)
    arr = exchange_overlap(arr, ["xt" if dim == 0 else "yt"], cyclic=False)

    return arr


@dist_context_only(noop_return_arg=2)
def _scatter_xy(nx, ny, arr):
    """Distribute an array decomposed along x (axis 0) and y (axis 1).

    Rank 0 holds the global array. It sends every other rank its chunk
    (including edge halos) with tag 50, then replaces ``arr`` by an
    uninitialized local-sized array. On other ranks ``arr`` is the local-sized
    target. All ranks then fill their overlap via ``exchange_overlap``
    (non-cyclic).

    Returns:
        The process-local chunk of shape ``(nx_local + 4, ny_local + 4) + arr.shape[2:]``.
    """
    from veros.core.operators import numpy as npx, update, at

    nxi, nyi = get_chunk_size(nx, ny)

    dim_grid = ["xt", "yt"] + [None] * (arr.ndim - 2)
    _, local_slice = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)

    if rst.proc_rank == 0:
        recvbuf = arr[local_slice]

        for proc in range(1, rst.proc_num):
            global_slice, _ = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            sendbuf = arr[global_slice]
            send(sendbuf, dest=proc, tag=50, comm=rs.mpi_comm)

        # arr changes shape in main process
        arr = npx.empty((nxi + 4, nyi + 4) + arr.shape[2:], dtype=arr.dtype)
    else:
        recvbuf = npx.empty_like(arr[local_slice])
        recvbuf = recv(recvbuf, source=0, tag=50, comm=rs.mpi_comm)

    arr = update(arr, at[local_slice], recvbuf)
    arr = exchange_overlap(arr, ["xt", "yt"], cyclic=False)

    return arr


@dist_context_only(noop_return_arg=0)
def scatter(arr, dimensions, var_grid):
    """Distribute a global array held by rank 0 to the process-local chunks.

    Args:
        arr: The global array (meaningful on rank 0).
        dimensions: Mapping providing the global extents ``"xt"`` and ``"yt"``.
        var_grid: Dimension names of ``arr``. Only the first two are
            inspected. ``None`` or a grid without decomposed leading
            dimensions causes a plain broadcast from rank 0.

    Returns:
        The process-local chunk (overlap filled via ``exchange_overlap``) for
        decomposed variables. Otherwise the value broadcast from rank 0.

    Raises:
        NotImplementedError: if the leading dimension is not decomposed but
            the second one is a y dimension.
    """
    from veros.core.operators import numpy as npx

    if var_grid is None:
        return _scatter_constant(arr)

    nx, ny = dimensions["xt"], dimensions["yt"]

    # only the two leading dimensions matter; an empty grid is not decomposed
    d1 = var_grid[0] if len(var_grid) > 0 else None
    d2 = var_grid[1] if len(var_grid) > 1 else None

    arr = npx.asarray(arr)

    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent
        return _scatter_constant(arr)

    if d1 in SCATTERED_DIMENSIONS[0] and d2 not in SCATTERED_DIMENSIONS[1]:
        # only x dependent
        return _scatter_1d(nx, ny, arr, 0)

    elif d1 in SCATTERED_DIMENSIONS[1]:
        # only y dependent
        return _scatter_1d(nx, ny, arr, 1)

    elif d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        # x and y dependent
        return _scatter_xy(nx, ny, arr)

    else:
        # reachable, e.g. ("zt", "yt"): second axis decomposed, leading axis not
        raise NotImplementedError(f"unsupported variable grid for scatter: {var_grid}")


@dist_context_only
def barrier():
    """Synchronize all processes of ``rs.mpi_comm`` (no-op when not distributed)."""
    rs.mpi_comm.barrier()


@dist_context_only
def abort():
    """Abort all processes of ``rs.mpi_comm`` via ``Abort()`` (no-op when not distributed)."""
    rs.mpi_comm.Abort()