import functools
import inspect
import threading
from contextlib import ExitStack, contextmanager
from veros import logger
from veros.state import VerosState
# stack helpers
class RoutineStack:
    """Tree-shaped record of nested routine calls.

    Every node is a two-element list ``[value, children]``. ``_current_idx`` is the
    path of child indices from the root to the routine that is currently active.
    When ``keep_full_stack`` is True, finished routines stay in the tree instead of
    being removed on :meth:`pop`.
    """

    def __init__(self):
        self.keep_full_stack = False
        self._stack = []
        self._current_idx = []

    @property
    def stack_level(self):
        """Current nesting depth (number of routines entered but not yet popped)."""
        return len(self._current_idx)

    def _descend(self, path):
        """Return the children list reached by following ``path`` from the root."""
        node_list = self._stack
        for child_pos in path:
            node_list = node_list[child_pos][1]
        return node_list

    def append(self, val):
        """Push ``val`` as a new child of the currently active routine."""
        children = self._descend(self._current_idx)
        self._current_idx.append(len(children))
        children.append([val, []])

    def pop(self):
        """Leave the currently active routine and return its value.

        The node is removed from the tree unless ``keep_full_stack`` is set.
        """
        siblings = self._descend(self._current_idx[:-1])
        entry = siblings[-1] if self.keep_full_stack else siblings.pop()
        self._current_idx.pop()
        return entry[0]
# global context
class _RoutineContext(threading.local):
    """Per-thread routine context.

    Attributes assigned on a plain ``threading.local()`` at import time only exist in
    the importing thread. Subclassing ``threading.local`` makes ``__init__`` run in
    every thread that touches the object, so each thread gets its own defaults.
    """

    def __init__(self):
        super().__init__()
        # False while inside a routine declared with dist_safe=False
        self.is_dist_safe = True
        # record of the currently executing routines/kernels
        self.routine_stack = RoutineStack()
        # mpi4jax token threaded through JIT-compiled kernels (None until first needed)
        self.mpi4jax_token = None


CURRENT_CONTEXT = _RoutineContext()
@contextmanager
def nullcontext():
    """Context manager that does nothing; used as a stand-in when no timer is given."""
    yield None
@contextmanager
def enter_routine(name, routine_obj, timer=None, dist_safe=True):
    """Track entry/exit of a routine or kernel on the thread-local routine stack.

    Parameters:
        name (str): Name used in trace log messages.
        routine_obj: Object pushed onto the routine stack; must be the one popped on exit.
        timer: Optional context manager timing the body; its ``last_time`` is logged on exit.
        dist_safe (bool): If False and running on more than one process, the context is
            marked as not distributed-safe for the duration of the body. Any exception
            raised inside it then triggers ``veros.distributed.abort()``.
    """
    from veros import runtime_state as rst
    from veros.distributed import abort

    routine_stack = CURRENT_CONTEXT.routine_stack
    logger.trace("{}> {}", "-" * routine_stack.stack_level, name)
    routine_stack.append(routine_obj)

    # only the routine that flips the flag is responsible for restoring it
    must_restore = CURRENT_CONTEXT.is_dist_safe and not dist_safe and rst.proc_num > 1
    if must_restore:
        CURRENT_CONTEXT.is_dist_safe = False

    body_ctx = timer if timer is not None else nullcontext()

    try:
        with body_ctx:
            yield
    except BaseException:
        # an error inside a non-dist-safe section aborts the distributed run
        if must_restore:
            abort()
        raise
    finally:
        if must_restore:
            CURRENT_CONTEXT.is_dist_safe = True

        popped = routine_stack.pop()
        assert popped is routine_obj

        duration = "" if timer is None else f"({timer.last_time:.3f}s)"
        logger.trace("<{} {} {}", "-" * routine_stack.stack_level, name, duration)
# helper functions
def _get_func_name(function):
return f"{inspect.getmodule(function).__name__}:{function.__qualname__}"
def _is_method(function):
if inspect.ismethod(function):
return True
# hack for unbound methods: check if first argument is called "self"
spec = inspect.getfullargspec(function)
return spec.args and spec.args[0] == "self"
# routine
def veros_routine(function=None, *, dist_safe=True, local_variables=()):
    """
    .. note::

        This decorator should be applied to all functions that access the Veros state object
        (even when subclassing :class:`veros.VerosSetup`).

    The first argument to the decorated function must be a VerosState instance (the second
    one for methods, after ``self``).

    Veros routines cannot return anything. All changes must be applied to the passed state object.

    Can be used both as ``@veros_routine`` and as ``@veros_routine(dist_safe=..., local_variables=...)``.

    Parameters:

        dist_safe (bool): If set to False, all variables specified in local_variables are synced
            to the root process before execution and synced back after. This means that the routine
            will only be executed on rank 0. Has no effect in non-distributed contexts.

        local_variables (Tuple[str]): List of variable names to be synced if dist_safe=False. This
            must include all variables retrieved from the state object throughout the routine (inputs
            *and* outputs).

    Raises:

        TypeError: If the decorated function does not accept the state argument.

    Example:

        >>> from veros import VerosSetup, veros_routine
        >>>
        >>> class MyModel(VerosSetup):
        ...     @veros_routine
        ...     def set_topography(self, state):
        ...         vs = state.variables
        ...         settings = state.settings
        ...         vs.kbot = npx.random.randint(0, settings.nz, size=vs.kbot.shape)

    """

    def inner_decorator(function):
        # for methods, the state is the argument after `self`
        narg = 1 if _is_method(function) else 0
        num_params = len(inspect.signature(function).parameters)
        if narg >= num_params:
            raise TypeError("Veros routines must take at least one argument")

        routine = VerosRoutine(function, state_argnum=narg, dist_safe=dist_safe, local_variables=local_variables)
        routine = functools.wraps(function)(routine)
        return routine

    if function is not None:
        return inner_decorator(function)

    return inner_decorator
class VerosRoutine:
    """Callable wrapper created by :func:`veros_routine`. Do not instantiate directly!

    Validates the state argument, unlocks the state's variables for the duration of the
    call, and handles routines marked ``dist_safe=False``. For those, the listed local
    variables are gathered into a ``DistSafeVariableWrapper``, the body runs only on rank 0,
    and the variables are scattered back afterwards.
    """

    def __init__(self, function, dist_safe=True, local_variables=(), state_argnum=0):
        if isinstance(local_variables, str):
            local_variables = (local_variables,)

        self.function = function
        self.dist_safe = dist_safe
        self.local_variables = local_variables
        # positional index of the VerosState argument (1 for methods)
        self.state_argnum = state_argnum
        self.name = _get_func_name(self.function)

    def __call__(self, *args, **kwargs):
        from veros import runtime_state as rst
        from veros.state import VerosState, DistSafeVariableWrapper
        from veros.core.operators import flush

        veros_state = args[self.state_argnum]

        if not isinstance(veros_state, VerosState):
            raise TypeError(f"Argument {self.state_argnum} to this Veros routine must be a VerosState object")

        timer = veros_state.profile_timers[self.name]

        with ExitStack() as es:
            vars_initialized = veros_state._variables is not None

            if vars_initialized:
                es.enter_context(veros_state.variables.unlock())

            execute = True
            restore_vars = False

            if not self.dist_safe:
                orig_vars = veros_state._variables
                # nested dist-unsafe routines reuse the wrapper installed by the outermost one
                if not isinstance(orig_vars, DistSafeVariableWrapper):
                    restore_vars = True
                    veros_state._variables = DistSafeVariableWrapper(orig_vars, self.local_variables)
                    try:
                        veros_state._variables._gather_variables()
                    except BaseException:
                        # do not leave the state pointing at a half-initialized wrapper
                        veros_state._variables = orig_vars
                        raise

                execute = rst.proc_rank == 0

            routine_ctx = enter_routine(name=self.name, routine_obj=self, timer=timer, dist_safe=self.dist_safe)

            out = None

            try:
                with routine_ctx:
                    if execute:
                        out = self.function(*args, **kwargs)
            finally:
                if restore_vars:
                    try:
                        veros_state._variables._scatter_variables()
                    finally:
                        # always restore the original container, even if scattering fails
                        veros_state._variables = orig_vars

            flush()

        if out is not None:
            logger.warning(
                f"Routine {self.name} returned object of type {type(out)}. Return objects are silently dropped."
            )

    def __get__(self, instance, _):
        # accessed on the class itself: return the routine unbound (descriptor protocol)
        if instance is None:
            return self
        return functools.partial(self.__call__, instance)

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"
# kernel
def veros_kernel(function=None, *, static_args=()):
    """Decorator that marks a function as a kernel that can be JIT compiled if supported
    by the backend.

    Kernels cannot modify the Veros state object. Instead, all modifications have to be
    returned explicitly.

    Can be used both as ``@veros_kernel`` and as ``@veros_kernel(static_args=...)``.

    Parameters:

        static_args (Tuple[str]): Names of kernel arguments that should be static.

    Example:

        >>> from veros import veros_kernel, KernelOutput
        >>>
        >>> @veros_kernel
        ... def double_psi(state):
        ...     vs = state.variables
        ...     vs.psi = 2 * vs.psi
        ...     return KernelOutput(psi=vs.psi)

    """

    def inner_decorator(function):
        kernel = VerosKernel(function, static_args=static_args)
        kernel = functools.wraps(function)(kernel)
        return kernel

    if function is not None:
        return inner_decorator(function)

    return inner_decorator
class VerosKernel:
    """Callable wrapper created by :func:`veros_kernel`. Do not instantiate directly!

    On the JAX backend the wrapped function is JIT-compiled on first call. With more than
    one process, an mpi4jax token is passed in as an extra last argument and the updated
    token is read back from the output.
    """

    def __init__(self, function, static_args=()):
        """Do some parameter introspection.

        Raises:
            ValueError: If the function uses *args, **kwargs or keyword-only parameters,
                or if a name in ``static_args`` is not a parameter of the function.
        """
        # make sure function signature is in the form we need
        self.name = _get_func_name(function)
        self.func_sig = inspect.signature(function)
        func_params = self.func_sig.parameters

        allowed_param_types = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        if any(p.kind not in allowed_param_types for p in func_params.values()):
            raise ValueError(f"Veros kernels do not support *args, **kwargs, or keyword-only parameters ({self.name})")

        # parse static args
        if isinstance(static_args, str):
            static_args = (static_args,)

        func_argnames = list(func_params.keys())
        self.static_argnums = []
        for static_arg in static_args:
            try:
                arg_index = func_argnames.index(static_arg)
            except ValueError:
                raise ValueError(
                    f'Veros kernel {self.name} has no argument "{static_arg}", but it is given in static_args'
                ) from None
            self.static_argnums.append(arg_index)

        self.function = function

    def _ensure_jitted(self, inject_tokens):
        """Replace ``self.function`` by a ``jax.jit``-compiled version (only once).

        If ``inject_tokens`` is set, the function is first wrapped so that it takes the
        mpi4jax token as an additional last argument and returns ``(out, token)``.
        """
        import jax

        CompiledFunction = type(jax.jit(lambda: None))
        if isinstance(self.function, CompiledFunction):
            return

        if inject_tokens:
            function = self.function

            @functools.wraps(function)
            def token_wrapper(*args):
                inputs = args[:-1]
                token = args[-1]
                # expose the token to communication primitives called inside the kernel
                CURRENT_CONTEXT.mpi4jax_token = token
                out = function(*inputs)
                token = CURRENT_CONTEXT.mpi4jax_token
                return out, token

            self.function = token_wrapper

        self.function = jax.jit(self.function, static_argnums=self.static_argnums)

    def __call__(self, *args, **kwargs):
        from veros import runtime_settings, runtime_state
        from veros.core.operators import flush

        use_jax = runtime_settings.backend == "jax"
        # mpi4jax tokens are only threaded through when running JAX with more than one process
        inject_tokens = use_jax and runtime_state.proc_num > 1

        # apply JIT
        if use_jax:
            self._ensure_jitted(inject_tokens)

        # create the token at call time: it is thread-local and may be unset even if the
        # kernel was already compiled (previously it was only created on first compilation)
        if inject_tokens and CURRENT_CONTEXT.mpi4jax_token is None:
            import jax

            CURRENT_CONTEXT.mpi4jax_token = jax.lax.create_token()

        # JAX only accepts positional args when using static_argnums
        # so convert everything to positional for consistency
        bound_args = self.func_sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        veros_state = None
        for argval in bound_args.arguments.values():
            if isinstance(argval, VerosState):
                veros_state = argval
                break

        called_with_state = veros_state is not None

        # when profiling, make sure all inputs are ready before starting the timer
        if runtime_settings.profile_mode:
            flush()

        timer = veros_state.profile_timers[self.name] if called_with_state else None

        with ExitStack() as es:
            if called_with_state:
                es.enter_context(veros_state.variables.unlock())

            call_args = list(bound_args.arguments.values())
            if inject_tokens:
                call_args.append(CURRENT_CONTEXT.mpi4jax_token)

            with enter_routine(self.name, self, timer):
                out = self.function(*call_args)
                if runtime_settings.profile_mode:
                    flush()

        if inject_tokens:
            out, token = out
            CURRENT_CONTEXT.mpi4jax_token = token

        return out

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"
def is_veros_routine(func):
    """Return True if ``func`` is a :class:`VerosRoutine`.

    Partials (as returned by ``VerosRoutine.__get__``) and bound methods are unwrapped
    before the check.
    """
    candidate = func.func if isinstance(func, functools.partial) else func

    if inspect.ismethod(candidate):
        candidate = candidate.__self__

    return isinstance(candidate, VerosRoutine)