# Source code for veros.diagnostics.overturning

from collections import OrderedDict
import os

from loguru import logger

from .. import veros_method
from .diagnostic import VerosDiagnostic
from ..core import density
from ..variables import Variable, allocate
from ..distributed import global_sum


# Coordinate variable for the potential-density ("sigma") axis of the
# isopycnal transport output. It does not change in time but is still
# written to restart files.
SIGMA = Variable(
    'Sigma axis', ('sigma',), 'kg/m^3', 'Sigma axis',
    output=True, time_dependent=False, write_to_restart=True
)

# Time-averaged quantities always produced by the overturning diagnostic:
#   trans     -- transport binned by sigma level, dims (yu, sigma)
#   vsf_iso   -- isopycnal transport mapped back onto depth levels, dims (yu, zw)
#   vsf_depth -- transport accumulated on geopotential levels, dims (yu, zw)
OVERTURNING_VARIABLES = OrderedDict()
OVERTURNING_VARIABLES['trans'] = Variable(
    'Meridional transport', ('yu', 'sigma'), 'm^3/s',
    'Meridional transport', output=True, write_to_restart=True
)
OVERTURNING_VARIABLES['vsf_iso'] = Variable(
    'Meridional transport', ('yu', 'zw'), 'm^3/s',
    'Meridional transport', output=True, write_to_restart=True
)
OVERTURNING_VARIABLES['vsf_depth'] = Variable(
    'Meridional transport', ('yu', 'zw'), 'm^3/s',
    'Meridional transport', output=True, write_to_restart=True
)

# Extra eddy-driven ("bolus") transports. The Overturning diagnostic uses
# these only when both enable_neutral_diffusion and enable_skew_diffusion
# are set.
ISONEUTRAL_VARIABLES = OrderedDict()
ISONEUTRAL_VARIABLES['bolus_iso'] = Variable(
    'Meridional transport', ('yu', 'zw'), 'm^3/s',
    'Meridional transport', output=True, write_to_restart=True
)
ISONEUTRAL_VARIABLES['bolus_depth'] = Variable(
    'Meridional transport', ('yu', 'zw'), 'm^3/s',
    'Meridional transport', output=True, write_to_restart=True
)


@veros_method(inline=True)
def zonal_sum(vs, arr):
    """Sum ``arr`` along its first axis, then reduce the result with ``global_sum``.

    Throughout this module, axis 0 of the arrays passed in is the x (zonal)
    index. The local partial sum is handed to ``global_sum`` (again over
    axis 0), presumably to combine contributions from distributed
    subdomains -- see ``..distributed``.
    """
    local_partial = np.sum(arr, axis=0)
    return global_sum(vs, local_partial, axis=0)


class Overturning(VerosDiagnostic):
    """Isopycnal overturning diagnostic.

    Computes and writes vertical streamfunctions (zonally averaged).

    Each call to :meth:`diagnose` adds one sample to the running sums.
    :meth:`output` divides them by the number of samples, writes them and
    resets the accumulators. When both neutral and skew diffusion are
    enabled, eddy-driven ("bolus") transports are accumulated as well.
    """

    name = 'overturning'  #:
    output_path = '{identifier}.overturning.nc'  #: File to write to. May contain format strings that are replaced with Veros attributes.
    output_frequency = None  #: Frequency (in seconds) in which output is written.
    sampling_frequency = None  #: Frequency (in seconds) in which variables are accumulated.
    p_ref = 2000.  #: Reference pressure for isopycnals

    def __init__(self, vs):
        self.sigma_var = SIGMA
        # Copy the module-level dict: updating it in place would leak the
        # bolus variables into every later Overturning instance, including
        # ones for setups where they are never allocated.
        self.mean_variables = OVERTURNING_VARIABLES.copy()
        if self._uses_eddy_velocities(vs):
            self.mean_variables.update(ISONEUTRAL_VARIABLES)
        self.variables = self.mean_variables.copy()
        self.variables.update({'sigma': self.sigma_var})

    @staticmethod
    def _uses_eddy_velocities(vs):
        """Return True if the eddy-driven (bolus) transports should be computed."""
        return vs.enable_neutral_diffusion and vs.enable_skew_diffusion

    def _initialize_output(self, vs):
        """Create the output file, including the extra ``sigma`` dimension."""
        self.initialize_output(vs, self.variables,
                               var_data={'sigma': self.sigma},
                               extra_dimensions={'sigma': self.nlevel})

    @veros_method
    def initialize(self, vs):
        """Allocate accumulators, build the sigma axis and set up output.

        The sigma axis has ``4 * vs.nz`` evenly spaced levels. They span the
        potential density (at ``p_ref``) of water with salinity 35 and
        temperature 30 (``sigs``) to temperature -2 (``sige``).
        """
        self.nitts = 0
        self.nlevel = vs.nz * 4
        self._allocate(vs)

        # sigma levels
        self.sige = density.get_potential_rho(vs, 35., -2., press_ref=self.p_ref)
        self.sigs = density.get_potential_rho(vs, 35., 30., press_ref=self.p_ref)
        self.dsig = float(self.sige - self.sigs) / (self.nlevel - 1)

        logger.debug(' Sigma ranges for overturning diagnostic:')
        logger.debug(' Start sigma0 = {:.1f}'.format(self.sigs))
        logger.debug(' End sigma0 = {:.1f}'.format(self.sige))
        logger.debug(' Delta sigma0 = {:.1e}'.format(self.dsig))
        if self._uses_eddy_velocities(vs):
            logger.debug(' Also calculating overturning by eddy-driven velocities')

        self.sigma[...] = self.sigs + self.dsig * np.arange(self.nlevel)

        # precalculate area below z levels: zonally summed (masked) face width
        # times layer thickness, accumulated along the z index
        zonal_width = zonal_sum(
            vs,
            vs.dxt[2:-2, np.newaxis, np.newaxis]
            * vs.cosu[np.newaxis, 2:-2, np.newaxis]
            * vs.maskV[2:-2, 2:-2, :]
        )
        self.zarea[2:-2, :] = np.cumsum(zonal_width * vs.dzt[np.newaxis, :], axis=1)

        self._initialize_output(vs)

    @veros_method
    def _allocate(self, vs):
        """Allocate the sigma axis, the z-area table and all accumulators."""
        self.sigma = allocate(vs, (self.nlevel,))
        self.zarea = allocate(vs, ('yu', 'zt'))
        self.trans = allocate(vs, ('yu', self.nlevel))
        self.vsf_iso = allocate(vs, ('yu', 'zt'))
        self.vsf_depth = allocate(vs, ('yu', 'zt'))
        if self._uses_eddy_velocities(vs):
            self.bolus_iso = allocate(vs, ('yu', 'zt'))
            self.bolus_depth = allocate(vs, ('yu', 'zt'))

    @veros_method
    def diagnose(self, vs):
        """Add one sample of every transport to the accumulators."""
        use_bolus = self._uses_eddy_velocities(vs)

        # sigma at p_ref
        sig_loc = allocate(vs, ('xt', 'yt', 'zt'))
        sig_loc[2:-2, 2:-1, :] = density.get_rho(vs,
                                                 vs.salt[2:-2, 2:-1, :, vs.tau],
                                                 vs.temp[2:-2, 2:-1, :, vs.tau],
                                                 self.p_ref)

        # transports below isopycnals and area below isopycnals
        # (density averaged onto the v-faces between neighbouring y cells)
        sig_loc_face = 0.5 * (sig_loc[2:-2, 2:-2, :] + sig_loc[2:-2, 3:-1, :])
        trans = allocate(vs, ('yu', self.nlevel))
        z_sig = allocate(vs, ('yu', self.nlevel))
        fac = (vs.dxt[2:-2, np.newaxis, np.newaxis]
               * vs.cosu[np.newaxis, 2:-2, np.newaxis]
               * vs.dzt[np.newaxis, np.newaxis, :]
               * vs.maskV[2:-2, 2:-2, :])
        for m in range(self.nlevel):
            # NOTE: vectorized version would be O(N^4) in memory
            # consider cythonizing if performance-critical
            mask = sig_loc_face > self.sigma[m]
            trans[2:-2, m] = zonal_sum(vs, np.sum(vs.v[2:-2, 2:-2, :, vs.tau] * fac * mask, axis=2))
            z_sig[2:-2, m] = zonal_sum(vs, np.sum(fac * mask, axis=2))
        self.trans += trans

        if use_bolus:
            bolus_trans = allocate(vs, ('yu', self.nlevel))
            # eddy-driven transports below isopycnals
            for m in range(self.nlevel):
                # NOTE: see above
                mask = sig_loc_face > self.sigma[m]
                bolus_trans[2:-2, m] = zonal_sum(
                    vs,
                    np.sum(
                        (vs.B1_gm[2:-2, 2:-2, 1:] - vs.B1_gm[2:-2, 2:-2, :-1])
                        * vs.dxt[2:-2, np.newaxis, np.newaxis]
                        * vs.cosu[np.newaxis, 2:-2, np.newaxis]
                        * vs.maskV[2:-2, 2:-2, 1:]
                        * mask[:, :, 1:],
                        axis=2
                    )
                    + vs.B1_gm[2:-2, 2:-2, 0]
                    * vs.dxt[2:-2, np.newaxis]
                    * vs.cosu[np.newaxis, 2:-2]
                    * vs.maskV[2:-2, 2:-2, 0]
                    * mask[:, :, 0]
                )

        # streamfunction on geopotentials
        zonal_transport = zonal_sum(
            vs,
            vs.dxt[2:-2, np.newaxis, np.newaxis]
            * vs.cosu[np.newaxis, 2:-2, np.newaxis]
            * vs.v[2:-2, 2:-2, :, vs.tau]
            * vs.maskV[2:-2, 2:-2, :]
        )
        self.vsf_depth[2:-2, :] += np.cumsum(zonal_transport * vs.dzt[np.newaxis, :], axis=1)

        if use_bolus:
            # streamfunction for eddy driven velocity on geopotentials
            self.bolus_depth[2:-2, :] += zonal_sum(
                vs,
                vs.dxt[2:-2, np.newaxis, np.newaxis]
                * vs.cosu[np.newaxis, 2:-2, np.newaxis]
                * vs.B1_gm[2:-2, 2:-2, :]
            )

        # interpolate from isopycnals to depth: z_sig is the area below each
        # isopycnal, zarea the area below each z level
        self.vsf_iso[2:-2, :] += self._interpolate_along_axis(
            vs, z_sig[2:-2, :], trans[2:-2, :], self.zarea[2:-2, :], 1
        )
        if use_bolus:
            self.bolus_iso[2:-2, :] += self._interpolate_along_axis(
                vs, z_sig[2:-2, :], bolus_trans[2:-2, :], self.zarea[2:-2, :], 1
            )

        self.nitts += 1

    @veros_method
    def _interpolate_along_axis(self, vs, coords, arr, interp_coords, axis=0):
        """Linearly interpolate ``arr`` (sampled at ``coords``) onto ``interp_coords``.

        For every target value, the nearest coordinate at or below it (``i_m``)
        and the nearest coordinate strictly above it (``i_p``) are found by
        value. The two samples are then blended linearly. If one side has no
        coordinate, both indices fall back to the other side (``dx == 0``), so
        targets outside the range take the value at the nearest endpoint
        instead of being extrapolated.

        Raises:
            ValueError: if ``coords`` does not match ``arr`` in shape.
        """
        # TODO: clean up this mess

        if coords.ndim == 1:
            if len(coords) != arr.shape[axis]:
                raise ValueError('Coordinate shape must match array shape along axis')
        elif coords.ndim == arr.ndim:
            if coords.shape != arr.shape:
                raise ValueError('Coordinate shape must match array shape')
        else:
            raise ValueError('Coordinate shape must match array dimensions')

        if axis != 0:
            arr = np.moveaxis(arr, axis, 0)
            coords = np.moveaxis(coords, axis, 0)
            interp_coords = np.moveaxis(interp_coords, axis, 0)

        # diff[t, k, ...] = coords[k] - target[t]
        diff = coords[np.newaxis, :, ...] - interp_coords[:, np.newaxis, ...]
        diff_m = np.where(diff <= 0., np.abs(diff), np.inf)
        diff_p = np.where(diff > 0., np.abs(diff), np.inf)
        i_m = np.asarray(np.argmin(diff_m, axis=1))
        i_p = np.asarray(np.argmin(diff_p, axis=1))
        # no neighbour on one side -> reuse the other side's index
        mask = np.all(np.isinf(diff_m), axis=1)
        i_m[mask] = i_p[mask]
        mask = np.all(np.isinf(diff_p), axis=1)
        i_p[mask] = i_m[mask]
        full_shape = (slice(None),) + (np.newaxis,) * (arr.ndim - 1)
        if coords.ndim == 1:
            # NOTE(review): multiplying by np.ones yields float index arrays,
            # which NumPy rejects for indexing. This class only calls the
            # helper with 2-D coords, so this branch appears unexercised --
            # confirm before relying on it.
            i_p_full = i_p[full_shape] * np.ones(arr.shape)
            i_m_full = i_m[full_shape] * np.ones(arr.shape)
        else:
            i_p_full = i_p
            i_m_full = i_m
        ii = np.indices(i_p_full.shape)
        i_p_slice = (i_p_full,) + tuple(ii[1:])
        i_m_slice = (i_m_full,) + tuple(ii[1:])
        dx = (coords[i_p_slice] - coords[i_m_slice])
        pos = np.where(dx == 0., 0., (coords[i_p_slice] - interp_coords) / (dx + 1e-12))
        return np.moveaxis(arr[i_p_slice] * (1. - pos) + arr[i_m_slice] * pos, 0, axis)

    @veros_method
    def output(self, vs):
        """Write the time-averaged transports and reset the accumulators.

        The output file is created first if it does not exist yet. Only
        variables flagged ``output`` and ``time_dependent`` are written.
        """
        if not os.path.isfile(self.get_output_file_name(vs)):
            self._initialize_output(vs)

        if self.nitts > 0:
            for var in self.mean_variables:
                getattr(self, var)[...] *= 1. / self.nitts

        var_metadata = {key: var for key, var in self.mean_variables.items()
                        if var.output and var.time_dependent}
        var_data = {key: getattr(self, key) for key, var in self.mean_variables.items()
                    if var.output and var.time_dependent}
        self.write_output(vs, var_metadata, var_data)

        self.nitts = 0
        for var in self.mean_variables:
            getattr(self, var)[...] = 0.

    @veros_method
    def read_restart(self, vs, infile):
        """Restore the sample counter and accumulated arrays from a restart file."""
        attributes, variables = self.read_h5_restart(vs, self.variables, infile)
        if attributes:
            self.nitts = attributes['nitts']
        if variables:
            for var, arr in variables.items():
                getattr(self, var)[...] = arr

    @veros_method
    def write_restart(self, vs, outfile):
        """Save the sample counter and every ``write_to_restart`` variable."""
        var_data = {key: getattr(self, key)
                    for key, var in self.variables.items() if var.write_to_restart}
        self.write_h5_restart(vs, {'nitts': self.nitts}, self.variables, var_data, outfile)