Source code for veros.routines

import functools
import inspect
import threading
from contextlib import ExitStack, contextmanager

from veros import logger

from veros.state import VerosState


# stack helpers


class RoutineStack:
    """Nested record of the routines that are currently being executed.

    Entries are stored as a tree of ``[value, children]`` pairs. ``_current_idx``
    is the path of child indices leading from the top level down to the
    innermost open entry, so its length equals the nesting depth.

    With ``keep_full_stack`` set, finished entries stay in the tree so the whole
    call tree remains available in ``_stack``; otherwise they are discarded on
    :meth:`pop`.
    """

    def __init__(self):
        self.keep_full_stack = False
        self._stack = []
        self._current_idx = []

    @property
    def stack_level(self):
        """Number of entries that have been appended but not yet popped."""
        return len(self._current_idx)

    def _children_at(self, path):
        """Return the list of sibling entries reached by walking ``path`` from the root."""
        children = self._stack
        for position in path:
            children = children[position][1]
        return children

    def append(self, val):
        """Open a new entry holding ``val`` as a child of the innermost open entry."""
        siblings = self._children_at(self._current_idx)
        self._current_idx.append(len(siblings))
        siblings.append([val, []])

    def pop(self):
        """Close the innermost open entry and return its value."""
        siblings = self._children_at(self._current_idx[:-1])
        # the innermost open entry is always the last one among its siblings
        entry = siblings[-1] if self.keep_full_stack else siblings.pop()
        self._current_idx.pop()
        return entry[0]


# global context

class _RoutineContext(threading.local):
    """Per-thread execution context shared by Veros routines and kernels.

    ``threading.local`` runs ``__init__`` separately for every thread that
    touches the object, so each thread starts with its own defaults. Setting
    attributes on a plain ``threading.local()`` instance at import time would
    initialize them only for the importing thread; any other thread would get
    an ``AttributeError`` on first access.
    """

    def __init__(self):
        super().__init__()
        # cleared by enter_routine while a non-dist-safe routine runs on more than one process
        self.is_dist_safe = True
        # routines/kernels currently being executed in this thread
        self.routine_stack = RoutineStack()
        # mpi4jax token threaded through JAX kernels; created on first use
        self.mpi4jax_token = None


CURRENT_CONTEXT = _RoutineContext()


@contextmanager
def nullcontext():
    """Context manager that does nothing; used in place of a missing timer."""
    yield None


@contextmanager
def enter_routine(name, routine_obj, timer=None, dist_safe=True):
    """Track entry into and exit from a Veros routine or kernel.

    Pushes ``routine_obj`` onto the current thread's routine stack, runs the
    body inside ``timer`` (if given) and pops the stack again on exit, logging
    both transitions at trace level.

    If ``dist_safe`` is False on a run with more than one process and no
    enclosing routine has already done so, ``CURRENT_CONTEXT.is_dist_safe`` is
    cleared for the duration of the body. In that case any exception raised in
    the body triggers ``veros.distributed.abort()`` before being re-raised.

    Args:
        name: Label used in the trace log messages.
        routine_obj: Object pushed onto the routine stack; the same object must
            be popped again on exit.
        timer: Optional context manager wrapping the body; its ``last_time``
            attribute is included in the exit log message.
        dist_safe: Whether the wrapped code may run on every process.
    """
    from veros import runtime_state as rst
    from veros.distributed import abort

    routine_stack = CURRENT_CONTEXT.routine_stack
    logger.trace("{}> {}", "-" * routine_stack.stack_level, name)
    routine_stack.append(routine_obj)

    # only the outermost non-dist-safe routine flips the flag, so only it restores it
    owns_dist_flag = bool(CURRENT_CONTEXT.is_dist_safe and not dist_safe and rst.proc_num > 1)
    if owns_dist_flag:
        CURRENT_CONTEXT.is_dist_safe = False

    body_ctx = timer if timer is not None else nullcontext()

    try:
        with body_ctx:
            yield
    # deliberately bare: also covers KeyboardInterrupt/SystemExit, and always re-raises
    except:  # noqa: E722
        if owns_dist_flag:
            abort()
        raise
    finally:
        if owns_dist_flag:
            CURRENT_CONTEXT.is_dist_safe = True

        popped = routine_stack.pop()
        assert popped is routine_obj

        timing = f"({timer.last_time:.3f}s)" if timer is not None else ""
        logger.trace("<{} {} {}", "-" * routine_stack.stack_level, name, timing)


# helper functions


def _get_func_name(function):
    return f"{inspect.getmodule(function).__name__}:{function.__qualname__}"


def _is_method(function):
    if inspect.ismethod(function):
        return True

    # hack for unbound methods: check if first argument is called "self"
    spec = inspect.getfullargspec(function)
    return spec.args and spec.args[0] == "self"


# routine


def veros_routine(function=None, *, dist_safe=True, local_variables=()):
    """Mark a function as a Veros routine.

    .. note::

      This decorator should be applied to all functions that access the Veros state object
      (even when subclassing :class:`veros.VerosSetup`).

    The first argument to the decorated function (after ``self`` for methods) must be a
    VerosState instance.

    Veros routines cannot return anything: a returned object is dropped with a warning.
    All changes must be applied to the passed state object.

    Can be used both bare (``@veros_routine``) and with keyword arguments
    (``@veros_routine(dist_safe=False, local_variables=(...))``).

    Parameters:
        dist_safe (bool): If set to False, all variables specified in local_variables are synced
            to the root process before execution and synced back after. This means that the
            routine will only be executed on rank 0. Has no effect in non-distributed contexts.

        local_variables (Tuple[str]): List of variable names to be synced if dist_safe=False.
            This must include all variables retrieved from the state object throughout the
            routine (inputs *and* outputs).

    Raises:
        TypeError: If the decorated function has no argument that could receive the state.

    Example:
        >>> from veros import VerosSetup, veros_routine
        >>>
        >>> class MyModel(VerosSetup):
        ...     @veros_routine
        ...     def set_topography(self, state):
        ...         vs = state.variables
        ...         settings = state.settings
        ...         vs.kbot = npx.random.randint(0, settings.nz, size=vs.kbot.shape)
    """

    def decorate(func):
        # for methods the state follows ``self``; otherwise it is the first argument
        state_argnum = 1 if _is_method(func) else 0
        if len(inspect.signature(func).parameters) <= state_argnum:
            raise TypeError("Veros routines must take at least one argument")

        wrapped = VerosRoutine(
            func,
            state_argnum=state_argnum,
            dist_safe=dist_safe,
            local_variables=local_variables,
        )
        return functools.wraps(func)(wrapped)

    if function is None:
        return decorate
    return decorate(function)


class VerosRoutine:
    """Callable wrapper created by :func:`veros_routine`. Do not instantiate directly!

    Parameters:
        function: The wrapped routine.
        dist_safe (bool): Whether the routine may run on every process. If False,
            ``local_variables`` are gathered before the call and scattered
            afterwards, and only rank 0 executes the routine.
        local_variables (Tuple[str] or str): Variable names to sync when
            ``dist_safe`` is False. A single string is treated as a 1-tuple.
        state_argnum (int): Position of the VerosState among the positional
            call arguments.
    """

    def __init__(self, function, dist_safe=True, local_variables=(), state_argnum=0):
        if isinstance(local_variables, str):
            local_variables = (local_variables,)

        self.function = function
        self.dist_safe = dist_safe
        self.local_variables = local_variables
        self.state_argnum = state_argnum
        self.name = _get_func_name(self.function)

    def __call__(self, *args, **kwargs):
        """Run the routine with unlocked state variables, syncing them if not dist-safe.

        Any return value of the wrapped function is dropped with a warning.

        Raises:
            TypeError: If positional argument ``state_argnum`` is not a VerosState.
        """
        from veros import runtime_state as rst
        from veros.state import VerosState, DistSafeVariableWrapper
        from veros.core.operators import flush

        veros_state = args[self.state_argnum]

        if not isinstance(veros_state, VerosState):
            raise TypeError(f"Argument {self.state_argnum} to this Veros routine must be a VerosState object")

        timer = veros_state.profile_timers[self.name]

        with ExitStack() as es:
            # only unlock the variable container if it has been created already
            vars_initialized = veros_state._variables is not None
            if vars_initialized:
                es.enter_context(veros_state.variables.unlock())

            execute = True
            restore_vars = False

            if not self.dist_safe:
                orig_vars = veros_state._variables
                # only the outermost non-dist-safe routine installs (and later removes) the wrapper
                if not isinstance(orig_vars, DistSafeVariableWrapper):
                    restore_vars = True
                    veros_state._variables = DistSafeVariableWrapper(orig_vars, self.local_variables)

                veros_state._variables._gather_variables()
                # non-dist-safe routines are executed on rank 0 only
                execute = rst.proc_rank == 0

            routine_ctx = enter_routine(name=self.name, routine_obj=self, timer=timer, dist_safe=self.dist_safe)

            out = None
            try:
                with routine_ctx:
                    if execute:
                        out = self.function(*args, **kwargs)
            finally:
                if restore_vars:
                    veros_state._variables._scatter_variables()
                    veros_state._variables = orig_vars

                flush()

        if out is not None:
            logger.warning(
                f"Routine {self.name} returned object of type {type(out)}. Return objects are silently dropped."
            )

    def __get__(self, instance, owner=None):
        """Bind the routine to ``instance`` when accessed as a method.

        Class-level access (``instance is None``) returns the routine itself, so
        ``Cls.routine(obj, state)`` behaves like an ordinary unbound method
        instead of binding ``None`` as ``self``.
        """
        if instance is None:
            return self
        return functools.partial(self.__call__, instance)

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"


# kernel


def veros_kernel(function=None, *, static_args=()):
    """Mark a function as a kernel that can be JIT compiled if supported by the backend.

    Kernels cannot modify the Veros state object. Instead, all modifications
    have to be returned explicitly.

    Can be used both bare (``@veros_kernel``) and with keyword arguments
    (``@veros_kernel(static_args=(...))``).

    Parameters:
        static_args (Tuple[str]): Names of kernel arguments that should be static.

    Example:
        >>> from veros import veros_kernel, KernelOutput
        >>>
        >>> @veros_kernel
        ... def double_psi(state):
        ...     vs = state.variables
        ...     vs.psi = 2 * vs.psi
        ...     return KernelOutput(psi=vs.psi)
    """

    def wrap(func):
        kernel = VerosKernel(func, static_args=static_args)
        return functools.wraps(func)(kernel)

    return wrap if function is None else wrap(function)


class VerosKernel:
    """Callable wrapper created by :func:`veros_kernel`. Do not instantiate directly!

    On the JAX backend the wrapped function is replaced by a JIT-compiled version
    on its first call, with ``static_argnums`` derived from ``static_args``. On
    JAX runs with more than one process, an mpi4jax token is additionally passed
    as an extra trailing argument. The token returned by the kernel is stored in
    ``CURRENT_CONTEXT.mpi4jax_token`` afterwards.
    """

    def __init__(self, function, static_args=()):
        """Validate the kernel signature and resolve ``static_args`` to argument positions.

        Raises:
            ValueError: If ``function`` takes ``*args``, ``**kwargs`` or keyword-only
                parameters, or if a name in ``static_args`` is not one of its parameters.
        """
        # make sure function signature is in the form we need
        self.name = _get_func_name(function)
        self.func_sig = inspect.signature(function)
        func_params = self.func_sig.parameters

        # __call__ passes every argument positionally, so only these parameter kinds work
        allowed_param_types = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        if any(p.kind not in allowed_param_types for p in func_params.values()):
            raise ValueError(f"Veros kernels do not support *args, **kwargs, or keyword-only parameters ({self.name})")

        # parse static args
        if isinstance(static_args, str):
            static_args = (static_args,)

        func_argnames = list(func_params.keys())
        # positional indices of the static parameters, as passed to jax.jit(static_argnums=...)
        self.static_argnums = []
        for static_arg in static_args:
            try:
                arg_index = func_argnames.index(static_arg)
            except ValueError:
                raise ValueError(
                    f'Veros kernel {self.name} has no argument "{static_arg}", but it is given in static_args'
                ) from None
            self.static_argnums.append(arg_index)

        self.function = function

    def __call__(self, *args, **kwargs):
        """Call the kernel with all arguments converted to positional form.

        If one of the arguments is a VerosState, its variables are unlocked for the
        duration of the call and its profiling timer (keyed by the kernel name) is
        used. Returns whatever the wrapped function returns; when tokens are
        injected, the token is split off before returning.
        """
        from veros import runtime_settings, runtime_state
        from veros.core.operators import flush

        # tokens are only threaded through on JAX runs with more than one process
        inject_tokens = runtime_settings.backend == "jax" and runtime_state.proc_num > 1

        # apply JIT
        if runtime_settings.backend == "jax":
            import jax

            # type of jitted callables for the installed JAX version; prevents wrapping twice
            CompiledFunction = type(jax.jit(lambda: None))

            if not isinstance(self.function, CompiledFunction):
                if inject_tokens:
                    function = self.function

                    # The token arrives as the extra last positional argument and is exposed
                    # to the kernel through CURRENT_CONTEXT. Whatever token is stored there
                    # after the kernel ran is returned alongside the kernel's output.
                    @functools.wraps(function)
                    def token_wrapper(*args):
                        inputs = args[:-1]
                        token = args[-1]
                        CURRENT_CONTEXT.mpi4jax_token = token
                        out = function(*inputs)
                        token = CURRENT_CONTEXT.mpi4jax_token
                        return out, token

                    # create the initial token on first use
                    if CURRENT_CONTEXT.mpi4jax_token is None:
                        CURRENT_CONTEXT.mpi4jax_token = jax.lax.create_token()

                    self.function = token_wrapper

                # the token is appended last, so static_argnums positions remain valid
                self.function = jax.jit(self.function, static_argnums=self.static_argnums)

        # JAX only accepts positional args when using static_argnums
        # so convert everything to positional for consistency
        bound_args = self.func_sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # the first VerosState argument (if any) provides the timer and variable unlocking
        veros_state = None
        for argval in bound_args.arguments.values():
            if isinstance(argval, VerosState):
                veros_state = argval
                break

        called_with_state = veros_state is not None

        # when profiling, make sure all inputs are ready before starting the timer
        if runtime_settings.profile_mode:
            flush()

        if called_with_state:
            timer = veros_state.profile_timers[self.name]
        else:
            timer = None

        with ExitStack() as es:
            if called_with_state:
                es.enter_context(veros_state.variables.unlock())

            args = list(bound_args.arguments.values())

            if inject_tokens:
                args.append(CURRENT_CONTEXT.mpi4jax_token)

            with enter_routine(self.name, self, timer):
                out = self.function(*args)

                # when profiling, make sure outputs are ready before the timer stops
                if runtime_settings.profile_mode:
                    flush()

        if inject_tokens:
            # token_wrapper returns (out, token); keep the token for the next kernel call
            out, token = out
            CURRENT_CONTEXT.mpi4jax_token = token

        return out

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"


def is_veros_routine(func):
    """Return True if ``func`` is a VerosRoutine.

    Also recognises a routine accessed through an instance, which is a
    ``functools.partial`` wrapping the routine's bound ``__call__``.
    """
    if isinstance(func, functools.partial):
        func = func.func

    if inspect.ismethod(func):
        func = func.__self__

    return isinstance(func, VerosRoutine)