import abc
# do not import veros.core here!
from veros import settings, time, signals, distributed, progress, runtime_settings as rs, logger
from veros.state import get_default_state
from veros.plugins import load_plugin
from veros.routines import veros_routine, is_veros_routine
from veros.timer import timer_context
[docs]
class VerosSetup(metaclass=abc.ABCMeta):
"""Main class for Veros, used for building a model and running it.
Note:
This class is meant to be subclassed. Subclasses need to implement the
methods :meth:`set_parameter`, :meth:`set_topography`, :meth:`set_grid`,
:meth:`set_coriolis`, :meth:`set_initial_conditions`, :meth:`set_forcing`,
:meth:`set_diagnostics`, and :meth:`after_timestep`.
Example:
>>> import matplotlib.pyplot as plt
>>> from veros import VerosSetup
>>>
>>> class MyModel(VerosSetup):
>>> ...
>>>
>>> simulation = MyModel()
>>> simulation.run()
>>> plt.imshow(simulation.state.variables.psi[..., 0])
>>> plt.show()
"""
__veros_plugins__ = tuple()
def __init__(self, override=None):
self.override_settings = override or {}
# this should be the first time the core routines are imported
import veros.core # noqa: F401
self._plugin_interfaces = tuple(load_plugin(p) for p in self.__veros_plugins__)
self._setup_done = False
self.state = get_default_state(plugin_interfaces=self._plugin_interfaces)
[docs]
@abc.abstractmethod
def set_parameter(self, state):
"""To be implemented by subclass.
First function to be called during setup.
Use this to modify the model settings.
Example:
>>> def set_parameter(self, state):
>>> settings = state.settings
>>> settings.nx, settings.ny, settings.nz = (360, 120, 50)
>>> settings.coord_degree = True
>>> settings.enable_cyclic = True
"""
pass
[docs]
@abc.abstractmethod
def set_initial_conditions(self, state):
"""To be implemented by subclass.
May be used to set initial conditions.
Example:
>>> @veros_method
>>> def set_initial_conditions(self, state):
>>> vs = state.variables
>>> vs.u = update(vs.u, at[:, :, :, vs.tau], npx.random.rand(vs.u.shape[:-1]))
"""
pass
[docs]
@abc.abstractmethod
def set_grid(self, state):
"""To be implemented by subclass.
Has to set the grid spacings :attr:`dxt`, :attr:`dyt`, and :attr:`dzt`,
along with the coordinates of the grid origin, :attr:`x_origin` and
:attr:`y_origin`.
Example:
>>> @veros_method
>>> def set_grid(self, state):
>>> vs = state.variables
>>> vs.x_origin, vs.y_origin = 0, 0
>>> vs.dxt = [0.1, 0.05, 0.025, 0.025, 0.05, 0.1]
>>> vs.dyt = 1.
>>> vs.dzt = [10, 10, 20, 50, 100, 200]
"""
pass
[docs]
@abc.abstractmethod
def set_coriolis(self, state):
"""To be implemented by subclass.
Has to set the Coriolis parameter :attr:`coriolis_t` at T grid cells.
Example:
>>> @veros_method
>>> def set_coriolis(self, state):
>>> vs = state.variables
>>> vs.coriolis_t = 2 * vs.omega * npx.sin(vs.yt[npx.newaxis, :] / 180. * vs.pi)
"""
pass
[docs]
@abc.abstractmethod
def set_topography(self, state):
"""To be implemented by subclass.
Must specify the model topography by setting :attr:`kbot`.
Example:
>>> @veros_method
>>> def set_topography(self, state):
>>> vs = state.variables
>>> vs.kbot = update(vs.kbot, at[...], 10)
>>> # add a rectangular island somewhere inside the domain
>>> vs.kbot = update(vs.kbot, at[10:20, 10:20], 0)
"""
pass
[docs]
@abc.abstractmethod
def set_forcing(self, state):
"""To be implemented by subclass.
Called before every time step to update the external forcing, e.g. through
:attr:`forc_temp_surface`, :attr:`forc_salt_surface`, :attr:`surface_taux`,
:attr:`surface_tauy`, :attr:`forc_tke_surface`, :attr:`temp_source`, or
:attr:`salt_source`. Use this method to implement time-dependent forcing.
Example:
>>> @veros_method
>>> def set_forcing(self, state):
>>> vs = state.variables
>>> current_month = (vs.time / (31 * 24 * 60 * 60)) % 12
>>> vs.surface_taux = vs._windstress_data[:, :, current_month]
"""
pass
[docs]
@abc.abstractmethod
def set_diagnostics(self, state):
"""To be implemented by subclass.
Called before setting up the :ref:`diagnostics <diagnostics>`. Use this method e.g. to
mark additional :ref:`variables <variables>` for output.
Example:
>>> @veros_method
>>> def set_diagnostics(self, state):
>>> state.diagnostics['snapshot'].output_variables += ['drho', 'dsalt', 'dtemp']
"""
pass
[docs]
@abc.abstractmethod
def after_timestep(self, state):
"""Called at the end of each time step. Can be used to define custom, setup-specific
events.
"""
pass
def _ensure_setup_done(self):
if not self._setup_done:
raise RuntimeError("setup() method has to be called before running the model")
def setup(self):
from veros import diagnostics, restart
from veros.core import numerics, external, isoneutral
setup_funcs = (
self.set_parameter,
self.set_grid,
self.set_coriolis,
self.set_topography,
self.set_initial_conditions,
self.set_diagnostics,
self.set_forcing,
self.after_timestep,
)
for f in setup_funcs:
if not is_veros_routine(f):
raise RuntimeError(
f"{f.__name__} method is not a Veros routine. Please make sure to decorate it "
"with @veros_routine and try again."
)
logger.info("Running model setup")
with self.state.timers["setup"]:
with self.state.settings.unlock():
self.set_parameter(self.state)
for setting, value in self.override_settings.items():
setattr(self.state.settings, setting, value)
settings.check_setting_conflicts(self.state.settings)
distributed.validate_decomposition(self.state.dimensions)
self.state.initialize_variables()
self.state.diagnostics.update(diagnostics.create_default_diagnostics(self.state))
for plugin in self._plugin_interfaces:
for diagnostic in plugin.diagnostics:
self.state.diagnostics[diagnostic.name] = diagnostic()
self.set_grid(self.state)
numerics.calc_grid(self.state)
self.set_coriolis(self.state)
numerics.calc_beta(self.state)
self.set_topography(self.state)
numerics.calc_topo(self.state)
self.set_initial_conditions(self.state)
numerics.calc_initial_conditions(self.state)
if self.state.settings.enable_streamfunction:
external.streamfunction_init(self.state)
for plugin in self._plugin_interfaces:
plugin.setup_entrypoint(self.state)
self.set_diagnostics(self.state)
diagnostics.initialize(self.state)
restart.read_restart(self.state)
self.set_forcing(self.state)
isoneutral.check_isoneutral_slope_crit(self.state)
self._setup_done = True
@veros_routine
def step(self, state):
from veros import diagnostics, restart
from veros.core import idemix, eke, tke, momentum, thermodynamics, advection, utilities, isoneutral, numerics
self._ensure_setup_done()
vs = state.variables
settings = state.settings
with state.timers["diagnostics"]:
restart.write_restart(state)
with state.timers["main"]:
with state.timers["forcing"]:
self.set_forcing(state)
if state.settings.enable_idemix:
with state.timers["idemix"]:
idemix.set_idemix_parameter(state)
with state.timers["eke"]:
eke.set_eke_diffusivities(state)
with state.timers["tke"]:
tke.set_tke_diffusivities(state)
with state.timers["momentum"]:
momentum.momentum(state)
with state.timers["thermodynamics"]:
thermodynamics.thermodynamics(state)
if settings.enable_eke or settings.enable_tke or settings.enable_idemix:
with state.timers["advection"]:
advection.calculate_velocity_on_wgrid(state)
with state.timers["eke"]:
if state.settings.enable_eke:
eke.integrate_eke(state)
with state.timers["idemix"]:
if state.settings.enable_idemix:
idemix.integrate_idemix(state)
with state.timers["tke"]:
if state.settings.enable_tke:
tke.integrate_tke(state)
with state.timers["boundary_exchange"]:
vs.u = utilities.enforce_boundaries(vs.u, settings.enable_cyclic_x)
vs.v = utilities.enforce_boundaries(vs.v, settings.enable_cyclic_x)
if settings.enable_tke:
vs.tke = utilities.enforce_boundaries(vs.tke, settings.enable_cyclic_x)
if settings.enable_eke:
vs.eke = utilities.enforce_boundaries(vs.eke, settings.enable_cyclic_x)
if settings.enable_idemix:
vs.E_iw = utilities.enforce_boundaries(vs.E_iw, settings.enable_cyclic_x)
with state.timers["momentum"]:
momentum.vertical_velocity(state)
with state.timers["plugins"]:
for plugin in self._plugin_interfaces:
with state.timers[plugin.name]:
plugin.run_entrypoint(state)
vs.itt = vs.itt + 1
vs.time = vs.time + settings.dt_tracer
self.after_timestep(state)
with state.timers["diagnostics"]:
if not numerics.sanity_check(state):
raise RuntimeError(f"solution diverged at iteration {vs.itt}")
isoneutral.isoneutral_diag_streamfunction(state)
diagnostics.diagnose(state)
diagnostics.output(state)
# NOTE: benchmarks parse this, do not change / remove
logger.debug(" Time step took {:.2f}s", state.timers["main"].last_time)
# permutate time indices
vs.taum1, vs.tau, vs.taup1 = vs.tau, vs.taup1, vs.taum1
[docs]
def run(self, show_progress_bar=None):
"""Main routine of the simulation.
Note:
Make sure to call :meth:`setup` prior to this function.
Arguments:
show_progress_bar (:obj:`bool`, optional): Whether to show fancy progress bar via tqdm.
By default, only show if stdout is a terminal and Veros is running on a single process.
"""
from veros import restart
self._ensure_setup_done()
vs = self.state.variables
settings = self.state.settings
time_length, time_unit = time.format_time(settings.runlen)
logger.info(f"\nStarting integration for {time_length:.1f} {time_unit}")
start_time = vs.time
# disable timers for first iteration
timer_context.active = False
pbar = progress.get_progress_bar(self.state, use_tqdm=show_progress_bar)
try:
with signals.signals_to_exception(), pbar:
while vs.time - start_time < settings.runlen:
self.step(self.state)
if not timer_context.active:
timer_context.active = True
pbar.advance_time(settings.dt_tracer)
except: # noqa: E722
logger.critical(f"Stopping integration at iteration {vs.itt}")
raise
else:
logger.success("Integration done\n")
finally:
restart.write_restart(self.state, force=True)
self._timing_summary()
def _timing_summary(self):
timing_summary = []
timing_summary.extend(
[
"",
"Timing summary:",
"(excluding first iteration)",
"---",
" setup time = {:.2f}s".format(self.state.timers["setup"].total_time),
" main loop time = {:.2f}s".format(self.state.timers["main"].total_time),
" forcing = {:.2f}s".format(self.state.timers["forcing"].total_time),
" momentum = {:.2f}s".format(self.state.timers["momentum"].total_time),
" pressure = {:.2f}s".format(self.state.timers["pressure"].total_time),
" friction = {:.2f}s".format(self.state.timers["friction"].total_time),
" thermodynamics = {:.2f}s".format(self.state.timers["thermodynamics"].total_time),
]
)
if rs.profile_mode:
timing_summary.extend(
[
" lateral mixing = {:.2f}s".format(self.state.timers["isoneutral"].total_time),
" vertical mixing = {:.2f}s".format(self.state.timers["vmix"].total_time),
" equation of state = {:.2f}s".format(self.state.timers["eq_of_state"].total_time),
]
)
timing_summary.extend(
[
" advection = {:.2f}s".format(self.state.timers["advection"].total_time),
" EKE = {:.2f}s".format(self.state.timers["eke"].total_time),
" IDEMIX = {:.2f}s".format(self.state.timers["idemix"].total_time),
" TKE = {:.2f}s".format(self.state.timers["tke"].total_time),
" boundary exchange = {:.2f}s".format(self.state.timers["boundary_exchange"].total_time),
" diagnostics and I/O = {:.2f}s".format(self.state.timers["diagnostics"].total_time),
" plugins = {:.2f}s".format(self.state.timers["plugins"].total_time),
]
)
timing_summary.extend(
[
" {:<22} = {:.2f}s".format(plugin.name, self.state.timers[plugin.name].total_time)
for plugin in self._plugin_interfaces
]
)
logger.debug("\n".join(timing_summary))
if rs.profile_mode:
print_profile_summary(self.state.profile_timers, self.state.timers["main"].total_time)
def print_profile_summary(profile_timers, main_loop_time):
profile_timings = ["", "Profile timings:", "[total time spent (% of main loop)]", "---"]
maxwidth = max(len(k) for k in profile_timers.keys())
profile_format_string = "{{:<{}}} = {{:.2f}}s ({{:.2f}}%)".format(maxwidth)
main_loop_time = max(main_loop_time, 1e-8) # prevent division by 0
for name, timer in profile_timers.items():
this_time = timer.total_time
if this_time == 0:
continue
profile_timings.append(profile_format_string.format(name, this_time, 100 * this_time / main_loop_time))
logger.diagnostic("\n".join(profile_timings))