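"""veros.decorators: backend injection (``veros_method``), distributed-context gating (``dist_context_only``) and signal-safe execution (``do_not_disturb``)."""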

import functools
import signal
import inspect
import threading

from loguru import logger

try:
    getargspec = inspect.getfullargspec
except AttributeError:  # python 2
    getargspec = inspect.getargspec


class _DecoratorContext(threading.local):
    """Per-thread state shared by the decorators in this module.

    Attributes assigned on a plain ``threading.local()`` instance at import time exist
    only in the importing thread; every other thread would raise ``AttributeError`` on
    first access. ``threading.local`` runs ``__init__`` once per thread, so initialising
    the defaults here makes them visible everywhere.
    """

    def __init__(self):
        # Set to False while a non-dist-safe veros_method has gathered the state onto
        # a DistributedVerosState; reset to True once that call finishes.
        self.is_dist_safe = True
        # Nesting depth of non-inline veros_method calls (used to indent trace logs).
        self.stack_level = 0


CONTEXT = _DecoratorContext()


def veros_method(function=None, **kwargs):
    """Decorator that injects the current backend as variable ``np`` into the wrapped function.

    .. note::

       This decorator should be applied to all functions that make use of the computational
       backend (even when subclassing :class:`veros.Veros`). The first argument to the
       decorated function must be a Veros instance.

    Can be used bare (``@veros_method``), with options
    (``@veros_method(dist_safe=False, local_variables=[...])``), or called directly as
    ``veros_method(func, **options)``. Options are honoured in every form.

    Keyword arguments:
        inline (bool): If True, the wrapper skips its trace log line and does not
            change the call-depth counter. Default False.
        dist_safe (bool): Set to False for functions that cannot work on decomposed
            arrays. When more than one process is running, the state is then gathered
            before the call, the body runs on rank 0 only, and the result is broadcast
            and the arrays scattered back afterwards. Default True.
        local_variables (list): Passed to ``DistributedVerosState.gather_arrays`` when
            gathering; presumably the names of the state arrays the function needs.
            Required if ``dist_safe=False``.
        dist_only (bool): Forwarded to ``_veros_method``. Default False.

    NOTE(review): unrecognised keyword arguments are silently ignored.

    Raises:
        ValueError: if ``dist_safe=False`` is given without ``local_variables``.

    Example:
        >>> from veros import Veros, veros_method
        >>>
        >>> class MyModel(Veros):
        >>>     @veros_method
        >>>     def set_topography(self):
        >>>         self.kbot[...] = np.random.randint(0, self.nz, size=self.kbot.shape)
    """
    # Parse options up front so they also apply when ``function`` is passed directly.
    inline = kwargs.pop('inline', False)
    dist_safe = kwargs.pop('dist_safe', True)

    if not dist_safe and 'local_variables' not in kwargs:
        raise ValueError('local_variables argument must be given if dist_safe=False')

    local_vars = kwargs.pop('local_variables', [])
    dist_only = kwargs.pop('dist_only', False)

    def inner_decorator(func):
        # methods receive the state as ``self`` (argument 0) ...; plain functions too,
        # so narg is 1 only if the first argument is literally named ``self``.
        narg = 1 if _is_method(func) else 0
        return _veros_method(
            func, inline=inline, narg=narg, dist_safe=dist_safe,
            local_vars=local_vars, dist_only=dist_only
        )

    if function is not None:
        return inner_decorator(function)

    return inner_decorator
def _is_method(function):
    """Return True if the first positional parameter of *function* is named ``self``.

    Used to decide whether the Veros state is positional argument 0 or 1.
    """
    spec = getargspec(function)
    return bool(spec.args) and spec.args[0] == 'self'


def _veros_method(function, inline=False, dist_safe=True, local_vars=None, dist_only=False, narg=0):
    """Build the wrapper used by :func:`veros_method`.

    On each call the returned wrapper:

    * unless ``inline``, logs a trace line and increments ``CONTEXT.stack_level``;
    * raises ``TypeError`` if ``args[narg]`` is not a ``VerosStateBase``;
    * for ``dist_safe=False`` on more than one process (and not already inside such a
      call), wraps the state in a ``DistributedVerosState`` and gathers ``local_vars``;
    * while ``CONTEXT.is_dist_safe`` is False, runs the body on rank 0 only;
    * binds ``get_backend(rs.backend)`` to ``np`` in the wrapped function's module
      globals for the duration of the call, restoring the previous binding afterwards;
    * after a gathered call, broadcasts the result and scatters the arrays back;
    * calls ``flush()`` once the call finished (or raised).

    ``CONTEXT.stack_level`` is restored even if an error occurs before the body runs.

    Parameters:
        function: function to wrap.
        inline: skip trace logging and call-depth tracking.
        dist_safe: whether the function can operate on decomposed arrays.
        local_vars: passed to ``DistributedVerosState.gather_arrays`` when gathering.
        dist_only: accepted but not used by this wrapper.
        narg: index of the Veros state among the positional arguments.
    """
    @functools.wraps(function)
    def veros_method_wrapper(*args, **kwargs):
        from . import runtime_settings as rs, runtime_state as rst
        from .backend import flush, get_backend
        from .state import VerosStateBase
        from .state_dist import DistributedVerosState
        from .distributed import broadcast

        if not inline:
            # getmodule returns None for functions without a resolvable module
            module = inspect.getmodule(function)
            module_name = module.__name__ if module is not None else function.__module__
            logger.trace(
                '{}> {}:{}',
                '-' * CONTEXT.stack_level,
                module_name,
                function.__name__
            )
            CONTEXT.stack_level += 1

        g = function.__globals__
        sentinel = object()
        oldvalue = sentinel
        np_installed = False

        try:
            veros_state = args[narg]
            if not isinstance(veros_state, VerosStateBase):
                raise TypeError('first argument to a veros_method must be a veros state object')

            reset_dist_safe = False
            if not CONTEXT.is_dist_safe:
                # already inside a gathered call: the state must be the distributed wrapper
                assert isinstance(veros_state, DistributedVerosState)
            elif not dist_safe and rst.proc_num > 1:
                reset_dist_safe = True

            if reset_dist_safe:
                dist_state = DistributedVerosState(veros_state)
                dist_state.gather_arrays(local_vars)
                func_state = dist_state
                CONTEXT.is_dist_safe = False
            else:
                func_state = veros_state

            execute = True
            if not CONTEXT.is_dist_safe:
                execute = rst.proc_rank == 0

            oldvalue = g.get('np', sentinel)
            g['np'] = get_backend(rs.backend)
            np_installed = True

            newargs = list(args)
            newargs[narg] = func_state

            res = None
            try:
                if execute:
                    res = function(*newargs, **kwargs)
            except BaseException:
                if reset_dist_safe:
                    CONTEXT.is_dist_safe = True
                raise
            else:
                if reset_dist_safe:
                    CONTEXT.is_dist_safe = True
                    res = broadcast(veros_state, res)
                    dist_state.scatter_arrays()
        finally:
            if np_installed:
                if oldvalue is sentinel:
                    del g['np']
                else:
                    g['np'] = oldvalue

            # always undo the increment above, even on early errors
            if not inline:
                CONTEXT.stack_level -= 1

            if np_installed:
                flush()

        return res

    return veros_method_wrapper


def dist_context_only(function):
    """Run *function* only in a distributed, dist-safe context.

    On a single process (``runtime_state.proc_num == 1``) or while ``CONTEXT.is_dist_safe``
    is False, the wrapped function is skipped. The wrapper then returns the second
    positional argument unchanged (assumed to be the input array), or ``None`` if fewer
    than two positional arguments were given.
    """
    @functools.wraps(function)
    def dist_context_only_wrapper(*args, **kwargs):
        # args are assumed to be (), (vs,), or (vs, arr, ...)
        from . import runtime_state as rst

        if rst.proc_num == 1 or not CONTEXT.is_dist_safe:
            # no-op for sequential execution: return input array unchanged
            return args[1] if len(args) > 1 else None

        return function(*args, **kwargs)

    return dist_context_only_wrapper


def _restore_signal_handler(sig, handler):
    """Reinstall *handler* for *sig*.

    ``signal.getsignal`` returns None for handlers not installed from Python, which
    ``signal.signal`` rejects; fall back to ``SIG_DFL`` in that case.
    """
    signal.signal(sig, signal.SIG_DFL if handler is None else handler)


def _forward_signal(handler, sig, frame):
    """Deliver signal *sig* to *handler* as returned by ``signal.getsignal``.

    The value may be a callable, ``SIG_IGN``, ``SIG_DFL`` or None; only callables can
    be invoked directly.
    """
    if callable(handler):
        handler(sig, frame)
    elif handler == signal.SIG_IGN:
        # the signal was ignored before we took over, so keep ignoring it
        return
    else:
        # SIG_DFL or unknown (None): apply the default action by re-raising the signal
        import os
        signal.signal(sig, signal.SIG_DFL)
        os.kill(os.getpid(), sig)


def do_not_disturb(function):
    """Decorator that catches SIGINT and SIGTERM signals (e.g. after keyboard interrupt) and
    makes sure that the function body is executed before exiting.

    Useful e.g. for ensuring that output files are written properly.

    The first signal received during the call is recorded and logged, then forwarded to
    the previously installed handler once the function has returned or raised. A second
    signal is forwarded immediately (force quit).

    NOTE(review): uses ``signal.signal``, which Python only allows in the main thread.
    """
    signals = (signal.SIGINT, signal.SIGTERM)

    @functools.wraps(function)
    def dnd_wrapper(*args, **kwargs):
        old_handlers = {s: signal.getsignal(s) for s in signals}
        signal_received = {'sig': None, 'frame': None}

        def handler(sig, frame):
            if signal_received['sig'] is None:
                signal_received['sig'] = sig
                signal_received['frame'] = frame
                logger.error('Signal {} received - cleaning up before exit', sig)
            else:
                # force quit if more than one signal is received
                _forward_signal(old_handlers[sig], sig, frame)

        for s in signals:
            signal.signal(s, handler)

        try:
            res = function(*args, **kwargs)
        finally:
            for s in signals:
                _restore_signal_handler(s, old_handlers[s])

            sig = signal_received['sig']
            if sig is not None:
                _forward_signal(old_handlers[sig], sig, signal_received['frame'])

        return res

    return dnd_wrapper