import contextlib
from collections import defaultdict, namedtuple
from collections.abc import Mapping
from copy import deepcopy

from veros import (
    plugins,
    runtime_settings as rs,
    runtime_state as rst,
    settings as settings_mod,
    timer,
    variables as var_mod,
)

def make_namedtuple(**kwargs):
    """Pack keyword arguments into an ad-hoc ``KernelOutput`` namedtuple.

    The tuple's fields appear in the order the keywords were passed.
    """
    field_names = list(kwargs.keys())
    output_cls = namedtuple("KernelOutput", field_names)
    return output_cls(*kwargs.values())


# Public alias for make_namedtuple.
KernelOutput = make_namedtuple

class StrictContainer:
    """A mutable container with fixed fields (optionally typed).

    Only names listed in ``fields`` may be assigned as public attributes; names
    starting with an underscore are reserved for internal attributes and are
    always allowed. If ``field_types`` maps a field to a callable, every
    assignment to that field is passed through the callable first.
    """

    __fields__ = ()
    __field_types__ = ()

    def __init__(self, fields, *args, field_types=None, default=None, **kwargs):
        """
        Args:
            fields: Iterable of public field names (must not start with ``_``).
            field_types: Optional dict mapping a subset of ``fields`` to a
                callable that converts values on assignment.
            default: Initial value of every field. It is stored as-is, without
                passing through ``field_types``.
            *args, **kwargs: Accepted for signature compatibility; unused.

        Raises:
            ValueError: If ``field_types`` is not a dict keyed by a subset of
                ``fields``, if a field collides with an existing attribute, or
                if a field name starts with an underscore.
        """
        self.__fields__ = fields

        if field_types is None:
            self.__field_types__ = {}
        else:
            if not isinstance(field_types, dict) or not set(field_types.keys()) <= set(fields):
                raise ValueError("field_types must be a dict with fields as keys")

            self.__field_types__ = field_types

        for k in fields:
            if k in vars(self):
                raise ValueError(f"Name collision: {k}")

            if k.startswith("_"):
                raise ValueError(f"Fields cannot start with _ (got: {k}).")

            # bypass our own __setattr__ (and any subclass lock) so the default is stored verbatim
            super().__setattr__(k, default)

    def __setattr__(self, key, val):
        if not key.startswith("_") and key not in self.__fields__:
            raise AttributeError(f"Unknown attribute {key}")

        if key in self.__field_types__:
            val = self.__field_types__[key](val)

        return super().__setattr__(key, val)

    def __contains__(self, val):
        return val in self.__fields__

    def fields(self):
        """Return the field names this container was created with."""
        return self.__fields__

    def values(self):
        """Return a generator over the current field values, in field order."""
        return (getattr(self, k) for k in self.__fields__)

    def items(self):
        """Return a generator over ``(name, value)`` pairs, in field order."""
        return ((k, getattr(self, k)) for k in self.__fields__)

    def update(self, other=None, **new_fields):
        """Assign several fields at once.

        Args:
            other: A namedtuple, dict, or StrictContainer to copy values from.
                Mutually exclusive with keyword arguments.
            **new_fields: Field values given as keywords.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: If both ``other`` and keyword arguments are given.
            TypeError: If ``other`` is of an unsupported type.
            AttributeError: If any key is not a known field.
        """
        if other is not None:
            if new_fields:
                raise ValueError("Either other or new_fields can be given")

            if hasattr(other, "_fields"):
                # other is namedtuple
                new_fields = dict(zip(other._fields, other))
            elif isinstance(other, (dict, StrictContainer)):
                new_fields = other
            else:
                raise TypeError(f"Cannot update from {type(other)} type")

        # reject unknown keys up front so nothing is assigned if any key is unknown
        for key, _ in new_fields.items():
            if key not in self.__fields__:
                raise AttributeError(f"unknown attribute {key}")

        for key, val in new_fields.items():
            setattr(self, key, val)

        return self

    def get(self, key, default=None):
        """Return the attribute ``key``, or ``default`` if it does not exist."""
        return getattr(self, key, default)

    def __repr__(self):
        attr_str = []

        for key, val in self.items():
            # poor-man's check for array-compatible types: summarize instead of dumping data
            if hasattr(val, "shape") and hasattr(val, "dtype"):
                val_repr = f"{type(val)} with shape {val.shape}, dtype {val.dtype}"
            else:
                val_repr = repr(val)

            attr_str.append(f"    {key} = {val_repr}")

        attr_str = ",\n".join(attr_str)

        return f"{self.__class__.__qualname__}(\n{attr_str}\n)"

class Lockable:
    """Mixin that rejects assignments to public attributes while locked.

    Attributes whose names start with an underscore can always be assigned.
    Use the :meth:`unlock` context manager to temporarily allow public
    assignments.
    """

    __locked__ = True

    @contextlib.contextmanager
    def unlock(self):
        """Context manager that temporarily unlocks the object.

        The previous lock state is restored on exit, also when the body raises.
        This makes nested ``unlock()`` blocks safe.
        """
        lock_state = self.__locked__
        self.__locked__ = False
        try:
            yield
        finally:
            self.__locked__ = lock_state

    def __setattr__(self, key, val):
        if not key.startswith("_") and self.__locked__:
            clsname = self.__class__.__qualname__
            raise RuntimeError(
                f"{clsname} is locked to modifications. If you know what you are doing, "
                f'you can unlock it via the "{clsname}.unlock()" context manager.'
            )
        return super().__setattr__(key, val)

class StaticDictProxy(Mapping):
    """Mapping view whose existing keys cannot be overwritten.

    New keys may still be added. They are stored in the wrapped mapping and,
    when a ``writeback`` mapping was given, are written there first.
    """

    def __init__(self, content, writeback=None):
        self._wrapped = content
        self._writeback = writeback

    def __getitem__(self, key):
        return self._wrapped[key]

    def __iter__(self):
        return iter(self._wrapped)

    def __len__(self):
        return len(self._wrapped)

    def __setitem__(self, key, val):
        if key in self:
            raise RuntimeError("Cannot overwrite existing values")

        writeback = self._writeback
        if writeback is not None:
            writeback[key] = val

        self._wrapped[key] = val

    def __repr__(self):
        return "{}({!r})".format(self.__class__.__qualname__, self._wrapped)

class VerosSettings(Lockable, StrictContainer):
    """Locked container of model settings, validated against their metadata.

    There is one field per key of ``settings_meta``. Every assignment to a
    known setting is converted through that setting's ``meta.type``.
    Modifications require the ``unlock()`` context manager.
    """

    def __init__(self, settings_meta):
        self.__metadata__ = settings_meta
        super().__init__(fields=settings_meta.keys())

        # start every setting at its (type-converted) default value
        default_settings = {k: meta.type(meta.default) for k, meta in settings_meta.items()}

        with self.unlock():
            self.update(default_settings)

    def __setattr__(self, key, val):
        # internal attributes and unknown names are handled (and rejected) by the base classes
        if key.startswith("_") or key not in self.__metadata__:
            return super().__setattr__(key, val)

        meta = self.__metadata__[key]
        val = meta.type(val)
        return super().__setattr__(key, val)
class VerosVariables(Lockable, StrictContainer):
    """Locked container holding the model's active variable arrays.

    Only variables whose metadata has ``active`` set become fields. They are
    allocated via ``var_mod.allocate`` using the metadata's ``dims`` and
    ``dtype``, and filled with ``initial`` when one is given. Accessing or
    assigning an inactive variable raises ``RuntimeError``. Assigned values are
    converted with the backend's ``asarray`` and checked against the expected
    shape.
    """

    def __init__(self, var_meta, dimensions):
        self.__metadata__ = var_meta
        self.__dimensions__ = dimensions

        active_vars = [key for key, val in var_meta.items() if val.active]
        super().__init__(fields=active_vars)

        with self.unlock():
            for key, val in var_meta.items():
                if not val.active:
                    continue

                allocate_kwargs = dict(dtype=val.dtype)
                if val.initial is not None:
                    allocate_kwargs.update(fill=val.initial)

                setattr(self, key, var_mod.allocate(dimensions, val.dims, **allocate_kwargs))

    def __getattr__(self, attr):
        # only reached when normal lookup fails, e.g. for inactive (never allocated) variables
        orig_getattr = super().__getattribute__

        try:
            var = orig_getattr("__metadata__")[attr]
        except (KeyError, AttributeError):
            return orig_getattr(attr)

        if not var.active:
            raise RuntimeError(
                f"Variable {attr} is not active in this configuration. Check your settings and try again."
            )

        return orig_getattr(attr)

    def __setattr__(self, key, val):
        if key.startswith("_") or key not in self.__metadata__:
            return super().__setattr__(key, val)

        var = self.__metadata__[key]

        # check whether variable is active
        if not var.active:
            raise RuntimeError(
                f"Variable {key} is not active in this configuration. Check your settings and try again."
            )

        # validate array type, shape and dtype
        if var.dtype is not None:
            expected_dtype = var.dtype
        else:
            expected_dtype = rs.float_type

        val = rst.backend_module.asarray(val, dtype=expected_dtype)

        expected_shape = self._get_expected_shape(var.dims)
        if val.shape != expected_shape:
            raise ValueError(f"Got unexpected shape for variable {key} (expected: {expected_shape}, got: {val.shape})")

        return super().__setattr__(key, val)

    def _get_expected_shape(self, dims):
        """Return the array shape a variable with ``dims`` must have."""
        return var_mod.get_shape(self.__dimensions__, dims)
class DistSafeVariableWrapper(VerosVariables):
    """Variables view that only allows access to an explicit set of local variables.

    It shares the internal (dunder) attributes of ``parent_state``, a variables
    object. Reading or writing a known variable that is not listed in
    ``local_variables`` raises ``RuntimeError``. ``_gather_variables`` and
    ``_scatter_variables`` move the listed variables between this wrapper and
    the parent using ``veros.distributed``.
    """

    def __init__(self, parent_state, local_variables):
        # set internal attributes to be identical to given variables object
        for attr, val in vars(parent_state).items():
            if not attr.startswith("__"):
                continue

            super().__setattr__(attr, val)

        self.__parent_state__ = parent_state
        self.__local_variables__ = local_variables

    def __getattr__(self, attr):
        orig_getattr = super().__getattribute__
        if attr in orig_getattr("__metadata__") and attr not in orig_getattr("__local_variables__"):
            raise RuntimeError(
                f"Cannot access variable {attr} because it was not collected. "
                "Consider adding it to the local_variables argument of @veros_routine."
            )
        return orig_getattr(attr)

    def __setattr__(self, attr, val):
        if attr.startswith("_"):
            return super().__setattr__(attr, val)

        if attr in self.__metadata__ and attr not in self.__local_variables__:
            raise RuntimeError(
                f"Cannot access variable {attr} because it was not collected. "
                "Consider adding it to the local_variables argument of @veros_routine."
            )

        return super().__setattr__(attr, val)

    def _iter_active_local_variables(self):
        """Yield each active name in ``__local_variables__``; raise ValueError on unknown names.

        Names are validated lazily, one at a time, in list order.
        """
        var_meta = self.__metadata__
        for var in self.__local_variables__:
            if var not in var_meta:
                raise ValueError(f"encountered unknown variable {var} in local variables")

            if var_meta[var].active:
                yield var

    def _gather_variables(self):
        """Gather each active local variable from the parent state into this wrapper."""
        from veros.distributed import gather

        for var in self._iter_active_local_variables():
            gathered_var = gather(getattr(self.__parent_state__, var), self.__dimensions__, self.__metadata__[var].dims)
            setattr(self, var, gathered_var)

    def _scatter_variables(self):
        """Scatter each active local variable from this wrapper back into the parent state."""
        from veros.distributed import scatter, barrier

        barrier()

        for var in self._iter_active_local_variables():
            scattered_var = scatter(getattr(self, var), self.__dimensions__, self.__metadata__[var].dims)
            setattr(self.__parent_state__, var, scattered_var)

    def _get_expected_shape(self, dims):
        # non-root processes hold local shapes (local=True unless proc_rank == 0)
        return var_mod.get_shape(self.__dimensions__, dims, local=rst.proc_rank != 0)

    def __repr__(self):
        return f"{self.__class__.__qualname__}(parent_state={self.__parent_state__}, local_variables={self.__local_variables__})"
class VerosState:
    """Holds all settings and model state for a given Veros run."""

    def __init__(self, var_meta, setting_meta, dimensions, diagnostics=None, plugin_interfaces=None):
        """
        Args:
            var_meta: Variable metadata, keyed by variable name.
            setting_meta: Setting metadata used to build :class:`VerosSettings`.
            dimensions: Mapping of dimension name to either an integer size or
                the name of a setting holding the size.
            diagnostics: Optional diagnostics mapping (defaults to ``{}``).
            plugin_interfaces: Optional sequence of loaded plugins (defaults to ``()``).
        """
        self._var_meta = var_meta
        self._variables = None  # allocated later by initialize_variables()
        self._settings = VerosSettings(setting_meta)
        self._dimensions = dimensions

        self._diagnostics = diagnostics if diagnostics is not None else {}
        self._plugin_interfaces = plugin_interfaces if plugin_interfaces is not None else ()

        timer_factory = timer.Timer
        self.timers = defaultdict(timer_factory)
        self.profile_timers = defaultdict(timer_factory)

    def __repr__(self):
        from textwrap import indent

        attr_str = []
        for attr in ("settings", "dimensions", "variables", "diagnostics", "plugin_interfaces"):
            # indent all lines of attr repr except the first
            attr_val = indent(repr(getattr(self, f"_{attr}")), " " * 4)[4:]
            attr_str.append(f"    {attr} = {attr_val}")

        attr_str = ",\n".join(attr_str)
        return f"{self.__class__.__qualname__}(\n{attr_str}\n)"

    def initialize_variables(self):
        """Finalize variable metadata against the current settings and allocate all variables.

        Raises:
            RuntimeError: If variables were already initialized.
        """
        if self._variables is not None:
            raise RuntimeError("Variables are already initialized.")

        self._var_meta = var_mod.manifest_metadata(self._var_meta, self._settings)
        self._variables = VerosVariables(self._var_meta, self._manifest_dimensions())

    @property
    def var_meta(self):
        return self._var_meta

    @property
    def variables(self):
        if self._variables is None:
            raise RuntimeError("Variables have not been initialized yet.")

        return self._variables

    @property
    def settings(self):
        return self._settings

    def _manifest_dimensions(self):
        """Resolve every dimension to a concrete ``int`` size (looking up setting names)."""
        concrete_dimensions = {}
        for dim_name, dim_target in self._dimensions.items():
            if isinstance(dim_target, str):
                dim_size = getattr(self._settings, dim_target)
            else:
                dim_size = dim_target

            concrete_dimensions[dim_name] = int(dim_size)

        return concrete_dimensions

    @property
    def dimensions(self):
        # new keys added through the proxy are also written into self._dimensions
        concrete_dimensions = self._manifest_dimensions()
        return StaticDictProxy(concrete_dimensions, self._dimensions)

    @property
    def diagnostics(self):
        return self._diagnostics

    @property
    def plugin_interfaces(self):
        return self._plugin_interfaces

    def to_xarray(self):
        """Export all active variables (ghost cells removed) and settings as an ``xarray.Dataset``."""
        import xarray as xr

        vs = self.variables
        # resolve dimensions once; nothing below changes settings or dimensions
        dimensions = self.dimensions

        coords = {}
        data_vars = {}

        for var_name, var_meta in self.var_meta.items():
            if not var_meta.active:
                continue

            data = var_mod.remove_ghosts(vs.get(var_name), var_meta.dims)
            data_vars[var_name] = xr.DataArray(
                data,
                dims=var_meta.dims,
                name=var_name,
                attrs=dict(
                    long_description=var_meta.long_description,
                    units=var_meta.units,
                    scale=var_meta.scale,
                ),
            )

            if var_meta.dims is None:
                continue

            for dim in var_meta.dims:
                if dim not in coords:
                    coords[dim] = range(var_mod.get_shape(dimensions, (dim,), include_ghosts=False)[0])

        # variables that share a name with a coordinate are represented by the coordinate only
        data_vars = {k: v for k, v in data_vars.items() if k not in coords}

        attrs = dict(self.settings.items())
        return xr.Dataset(data_vars, coords=coords, attrs=attrs)
def get_default_state(use_plugins=None):
    """Build a fresh VerosState from the default settings, dimensions and variables.

    Args:
        use_plugins: Optional iterable of plugins, each loaded via
            ``plugins.load_plugin``. Their settings and variables are merged
            into the defaults.

    Returns:
        A :class:`VerosState` whose variables are not initialized yet.
    """
    if use_plugins is not None:
        plugin_interfaces = tuple(plugins.load_plugin(p) for p in use_plugins)
    else:
        plugin_interfaces = tuple()

    # deep copies keep the module-level defaults untouched by plugin merges
    default_settings = deepcopy(settings_mod.SETTINGS)
    for plugin in plugin_interfaces:
        default_settings.update(plugin.settings)

    default_dimensions = deepcopy(var_mod.DIM_TO_SHAPE_VAR)

    var_meta = deepcopy(var_mod.VARIABLES)
    for plugin in plugin_interfaces:
        var_meta.update(plugin.variables)

    return VerosState(var_meta, default_settings, default_dimensions, plugin_interfaces=plugin_interfaces)


def veros_state_pytree_flatten(state):
    """Flatten a VerosState into ``([variables], (aux_data, pseudo_hash))``.

    NOTE(review): presumably registered as a pytree node elsewhere (e.g. with JAX); confirm.
    """
    aux_data = tuple((k, v) for k, v in vars(state).items() if k != "_variables")

    # ensure that functions are re-traced when settings change
    with state.settings.unlock():
        pseudo_hash = hash(tuple(state.settings.items()))

    return ([state.variables], (aux_data, pseudo_hash))


def veros_state_pytree_unflatten(aux_data, leaves):
    """Rebuild a VerosState from the output of :func:`veros_state_pytree_flatten`."""
    assert len(leaves) == 1
    variables = leaves[0]

    # by-pass __init__ and set attributes manually
    state = VerosState.__new__(VerosState)
    state._variables = variables

    state_attrs, _ = aux_data
    for attr, val in state_attrs:
        setattr(state, attr, val)

    return state


def veros_variables_pytree_flatten(variables):
    """Flatten a VerosVariables into its field values plus the internal attributes needed to rebuild it."""
    aux_attrs = (
        "__dimensions__",
        "__metadata__",
        "__fields__",
        "__locked__",
    )
    leaves = list(variables.values())
    aux_data = (tuple(variables.fields()), tuple((attr, getattr(variables, attr)) for attr in aux_attrs))
    return (leaves, aux_data)


def veros_variables_pytree_unflatten(aux_data, leaves):
    """Rebuild a VerosVariables from the output of :func:`veros_variables_pytree_flatten`."""
    keys, aux_attrs = aux_data

    # by-pass __init__ and set attributes manually
    variables = VerosVariables.__new__(VerosVariables)

    for key, val in aux_attrs:
        setattr(variables, key, val)

    with variables.unlock():
        for key, val in zip(keys, leaves):
            setattr(variables, key, val)

    return variables


def dist_safe_wrapper_pytree_flatten(variables):
    """Flatten a DistSafeVariableWrapper; only its local variables become leaves."""
    aux_attrs = (
        "__dimensions__",
        "__metadata__",
        "__fields__",
        "__locked__",
        "__local_variables__",
        "__parent_state__",
    )

    with variables.unlock():
        leaves = [getattr(variables, attr) for attr in variables.__local_variables__]

    aux_data = (tuple(variables.__local_variables__), tuple((attr, getattr(variables, attr)) for attr in aux_attrs))
    return (leaves, aux_data)


def dist_safe_wrapper_pytree_unflatten(aux_data, leaves):
    """Rebuild a DistSafeVariableWrapper from the output of :func:`dist_safe_wrapper_pytree_flatten`."""
    keys, aux_attrs = aux_data

    # by-pass __init__ and set attributes manually
    variables = DistSafeVariableWrapper.__new__(DistSafeVariableWrapper)

    for key, val in aux_attrs:
        setattr(variables, key, val)

    with variables.unlock():
        for key, val in zip(keys, leaves):
            setattr(variables, key, val)

    return variables


def resize_dimension(state, dimension, new_size):
    """Resize a dimension of an existing VerosState object.

    Every active variable that uses ``dimension`` is re-allocated with the new
    size. Its previous contents are discarded (reset to 0).
    """
    state._dimensions[dimension] = new_size
    state.variables.__dimensions__[dimension] = new_size

    variables = state.variables
    # resolve the (now updated) dimensions once instead of once per variable
    dimensions = state.dimensions

    with variables.unlock():
        for var in variables.fields():
            var_meta = variables.__metadata__[var]
            var_dims = var_meta.dims

            if var_dims is None or dimension not in var_dims:
                continue

            setattr(variables, var, var_mod.allocate(dimensions, var_dims, dtype=var_meta.dtype))