# Source code for veros.diagnostics.diagnostic

import os

from loguru import logger

from .io_tools import netcdf as nctools, hdf5 as h5tools
from ..decorators import veros_method, do_not_disturb
from .. import time, runtime_state, distributed, runtime_settings


class VerosDiagnostic:
    """Base class for diagnostics. Provides an interface and wrappers for common I/O.

    Any diagnostic needs to implement the five interface methods and set some attributes.
    """

    name = None  #: Name that identifies the current diagnostic
    #: Frequency with which :meth:`diagnose` is called (falsy disables it)
    sampling_frequency = 0.
    #: Frequency with which :meth:`output` is called (falsy disables it)
    output_frequency = 0.
    #: Format template for the output file; filled in from ``vars(vs)``
    output_path = None

    def __init__(self, vs):
        pass

    def _not_implemented(self, vs, *args, **kwargs):
        """Placeholder bound to every interface method a subclass must override.

        Extra positional / keyword arguments are accepted and ignored so that
        calling an unimplemented hook always raises ``NotImplementedError``,
        never a ``TypeError`` caused by an argument-count mismatch.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('must be implemented by subclass')

    initialize = _not_implemented
    """Called at the end of setup. Use this to process user settings and handle setup."""

    diagnose = _not_implemented
    """Called with frequency ``sampling_frequency``."""

    output = _not_implemented
    """Called with frequency ``output_frequency``."""

    write_restart = _not_implemented
    """Responsible for writing restart files."""

    read_restart = _not_implemented
    """Responsible for reading restart files."""

    @veros_method
    def get_output_file_name(self, vs):
        """Return ``output_path`` with its format fields filled from ``vars(vs)``."""
        return self.output_path.format(**vars(vs))

    @do_not_disturb
    @veros_method
    def initialize_output(self, vs, variables, var_data=None, extra_dimensions=None):
        """Create the output file and define (and partly fill) its variables.

        Does nothing when ``vs.diskless_mode`` is set, or when both
        ``output_frequency`` and ``sampling_frequency`` are falsy.

        Arguments:
            vs: State object; provides ``diskless_mode``, ``force_overwrite``
                and the values used to format the output path.
            variables: Mapping of variable name to variable metadata. Each
                metadata object must expose ``time_dependent``.
            var_data: Mapping of variable name to data. Required for every
                variable that is not time dependent.
            extra_dimensions: Optional mapping of dimension name to size that
                is added to the file before the variables are defined.

        Raises:
            IOError: if the output file exists and ``vs.force_overwrite`` is
                not set.
            ValueError: if a variable that is not time dependent has no entry
                in ``var_data``.
        """
        if vs.diskless_mode or (not self.output_frequency and not self.sampling_frequency):
            return

        output_path = self.get_output_file_name(vs)
        if os.path.isfile(output_path) and not vs.force_overwrite:
            raise IOError('output file {} for diagnostic "{}" exists '
                          '(change output path or enable force_overwrite setting)'
                          .format(output_path, self.name))

        with nctools.threaded_io(vs, output_path, 'w') as outfile:
            nctools.initialize_file(vs, outfile)

            if extra_dimensions:
                for dim_id, size in extra_dimensions.items():
                    nctools.add_dimension(vs, dim_id, size, outfile)

            for key, var in variables.items():
                # initialize_file may already have created some variables
                if key not in outfile.variables:
                    nctools.initialize_variable(vs, key, var, outfile)

                # constant variables are written once, right here
                if not var.time_dependent:
                    if var_data is None or key not in var_data:
                        raise ValueError('var_data argument must be given for constant variables')
                    nctools.write_variable(vs, key, var, var_data[key], outfile)

    @do_not_disturb
    @veros_method
    def write_output(self, vs, variables, variable_data):
        """Write the given variables to the output file at its current time step.

        Does nothing when ``vs.diskless_mode`` is set.

        Arguments:
            vs: State object; provides ``diskless_mode`` and ``time``.
            variables: Mapping of variable name to variable metadata.
            variable_data: Mapping of variable name to the data to write; must
                contain every key of ``variables``.
        """
        if vs.diskless_mode:
            return

        with nctools.threaded_io(vs, self.get_output_file_name(vs), 'r+') as outfile:
            time_step = nctools.get_current_timestep(vs, outfile)
            # vs.time is converted from seconds to days before it is stored
            current_days = time.convert_time(vs.time, 'seconds', 'days')
            nctools.advance_time(vs, time_step, current_days, outfile)

            for key, var in variables.items():
                nctools.write_variable(vs, key, var, variable_data[key],
                                       outfile, time_step=time_step)

    @veros_method
    def read_h5_restart(self, vs, var_meta, restart_filename):
        """Read this diagnostic's group from an HDF5 restart file.

        Arguments:
            vs: State object passed through to the I/O and distribution helpers.
            var_meta: Mapping of variable name to metadata exposing ``dims``.
            restart_filename: Path of the restart file to read.

        Returns:
            Tuple ``(attributes, variables)``: the attributes of the group
            named ``self.name`` (each converted with ``.item()``) and a dict
            of the variables read from that group.

        Raises:
            IOError: if ``restart_filename`` is not an existing file.
        """
        # NOTE(review): ``np`` is not imported in this module; presumably it is
        # made available by the ``veros_method`` decorator - confirm.
        if not os.path.isfile(restart_filename):
            raise IOError('restart file {} not found'.format(restart_filename))

        logger.info(' Reading restart data for diagnostic {} from {}',
                    self.name, restart_filename)

        with h5tools.threaded_io(vs, restart_filename, 'r') as infile:
            variables = {}
            for key, var in infile[self.name].items():
                if np.isscalar(var):
                    variables[key] = var
                    continue

                local_shape = distributed.get_local_size(
                    vs, var.shape, var_meta[key].dims, include_overlap=True
                )
                gidx, lidx = distributed.get_chunk_slices(
                    vs, var_meta[key].dims[:var.ndim], include_overlap=True
                )

                variables[key] = np.empty(local_shape, dtype=str(var.dtype))

                # the bohrium backend gets an explicit cast to the target dtype
                if runtime_settings.backend == 'bohrium':
                    variables[key][lidx] = var[gidx].astype(variables[key].dtype)
                else:
                    variables[key][lidx] = var[gidx]

                distributed.exchange_overlap(vs, variables[key], var_meta[key].dims)

            attributes = {key: var.item() for key, var in infile[self.name].attrs.items()}

        return attributes, variables

    @do_not_disturb
    @veros_method
    def write_h5_restart(self, vs, attributes, var_meta, var_data, outfile):
        """Write attributes and variables into this diagnostic's restart group.

        Arguments:
            vs: State object; provides ``enable_hdf5_gzip_compression``.
            attributes: Mapping of attribute name to value, stored as
                attributes of the group.
            var_meta: Mapping of variable name to metadata exposing ``dims``.
            var_data: Mapping of variable name to the local array to write.
            outfile: Open, writable HDF5 file (or group) in which a group
                named ``self.name`` is created if it does not exist yet.
        """
        group = outfile.require_group(self.name)

        for key, var in var_data.items():
            # arrays offering copy2numpy() are converted first; others are used as-is
            try:
                var = var.copy2numpy()
            except AttributeError:
                pass

            global_shape = distributed.get_global_size(
                vs, var.shape, var_meta[key].dims, include_overlap=True
            )
            gidx, lidx = distributed.get_chunk_slices(
                vs, var_meta[key].dims, include_overlap=True
            )

            # one HDF5 chunk per local (overlap-free) block
            kwargs = dict(
                exact=True,
                chunks=tuple(
                    distributed.get_local_size(
                        vs, var.shape, var_meta[key].dims, include_overlap=False
                    )
                )
            )
            # compression is only requested when running with exactly one process
            if vs.enable_hdf5_gzip_compression and runtime_state.proc_num == 1:
                kwargs.update(
                    compression='gzip',
                    compression_opts=9
                )

            group.require_dataset(key, global_shape, var.dtype, **kwargs)
            group[key][gidx] = var[lidx]

        for key, val in attributes.items():
            try:
                val = val.copy2numpy()
            except AttributeError:
                pass

            group.attrs[key] = val