.github/workflows/unit_tests.yml (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install -U pip # upgrade pip
-          pip install '.[develop]'
+          pip install '.[develop]' './lstm_ewts'
       - name: Echo dependency versions
         run: |
           pip freeze
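For context: pip accepts multiple requirement specifiers in one invocation, so the added './lstm_ewts' installs the sibling package from its local path in the same step as the main package's develop extras. A minimal post-install smoke test might look like the following (assuming lstm_ewts exposes the two names imported by lstm/bmi_lstm.py later in this diff):

    # Hypothetical smoke test: confirm the locally installed lstm_ewts
    # package exposes the names that lstm/bmi_lstm.py imports below.
    from lstm_ewts import MODULE_NAME, configure_logging

    configure_logging()
    print(MODULE_NAME)  # the package's logger namespace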
lstm/bmi_lstm.py (62 changes: 40 additions & 22 deletions)
@@ -62,9 +62,12 @@

 from . import nextgen_cuda_lstm
 from .base import BmiBase
-from .logger import configure_logging, logger
+from lstm_ewts import configure_logging, MODULE_NAME
 from .model_state import State, StateFacade, Var
 
+import logging
+LOG = logging.getLogger(MODULE_NAME)
+
 # -------------- Dynamic Attributes -----------------------------
 _dynamic_input_vars = [
     ("land_surface_radiation~incoming~longwave__energy_flux", "W m-2"),
@@ -85,6 +88,7 @@
 _output_vars = [
     ("land_surface_water__runoff_volume_flux", "m3 s-1"),
     ("land_surface_water__runoff_depth", "m"),
+    ("precipitation_rate", "mm s-1"),
 ]
 
 # -------------- Name Mappings -----------------------------
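Registries like _output_vars pair a CSDMS-style variable name with its unit string, and BMI metadata queries are then typically derived from them. A sketch of that wiring, under the assumption that the accessors read the list directly (the real methods in this repository may differ):

    # Illustrative only: how a (name, unit) registry commonly backs the
    # BMI metadata queries. Method names follow the BMI spec.
    _output_vars = [
        ("land_surface_water__runoff_volume_flux", "m3 s-1"),
        ("land_surface_water__runoff_depth", "m"),
        ("precipitation_rate", "mm s-1"),
    ]

    def get_output_var_names() -> tuple[str, ...]:
        return tuple(name for name, _ in _output_vars)

    def get_var_units(name: str) -> str:
        return dict(_output_vars)[name]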
@@ -173,6 +177,9 @@ def update(self, state: Valuer) -> typing.Iterable[Var]:
         with torch.no_grad():
             inputs = gather_inputs(state, self.input_names)
 
+            # Retrieve precipitation value for output
+            precipitation_mm_h = state.value("atmosphere_water__liquid_equivalent_precipitation_rate")
+
             scaled = scale_inputs(
                 inputs, self.scalars.input_mean, self.scalars.input_std
             )
@@ -190,6 +197,7 @@ def update(self, state: Valuer) -> typing.Iterable[Var]:
             self.scalars.output_mean,
             self.scalars.output_std,
             self.output_scaling_factor_cms,
+            precipitation_mm_h,
         )
 
 
@@ -282,7 +290,7 @@ def initialize_lstm(cfg: dict[str, typing.Any]) -> nextgen_cuda_lstm.Nextgen_Cud
 def gather_inputs(
     state: Valuer, internal_input_names: typing.Iterable[str]
 ) -> npt.NDArray:
-    logger.debug("Collecting LSTM inputs ...")
+    LOG.debug("Collecting LSTM inputs ...")
 
     input_list = []
     for lstm_name in internal_input_names:
@@ -291,29 +299,29 @@ def gather_inputs(
         assert value.size == 1, "`value` should be a single scalar in a 1d array"
         input_list.append(value[0])
 
-        logger.debug(f" {lstm_name=}")
-        logger.debug(f" {bmi_name=}")
-        logger.debug(f" {type(value)=}")
-        logger.debug(f" {value=}")
+        LOG.debug(f" {lstm_name=}")
+        LOG.debug(f" {bmi_name=}")
+        LOG.debug(f" {type(value)=}")
+        LOG.debug(f" {value=}")
 
     collected = bmi_array(input_list)
-    logger.debug(f"Collected inputs: {collected}")
+    LOG.debug("Collected inputs: %s", collected)
     return collected
 
 
 def scale_inputs(
     input: npt.NDArray, mean: npt.NDArray, std: npt.NDArray
 ) -> npt.NDArray:
-    logger.debug("Normalizing the tensor...")
-    logger.debug(" input_mean =", mean)
-    logger.debug(" input_std =", std)
+    LOG.debug("Normalizing the tensor...")
+    LOG.debug(" input_mean = %s", mean)
+    LOG.debug(" input_std = %s", std)
 
     # Center and scale the input values for use in torch
     input_array_scaled = (input - mean) / std
-    logger.debug(f"### input_array ={input}")
-    logger.debug(f"### dtype(input_array) ={input.dtype}")
-    logger.debug(f"### type(input_array_scaled) ={type(input_array_scaled)}")
-    logger.debug(f"### dtype(input_array_scaled) ={input_array_scaled.dtype}")
+    LOG.debug("### input_array = %s", input)
+    LOG.debug("### dtype(input_array) = %s", input.dtype)
+    LOG.debug("### type(input_array_scaled) = %s", type(input_array_scaled))
+    LOG.debug("### dtype(input_array_scaled) = %s", input_array_scaled.dtype)
     return input_array_scaled
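The switch from eager f-strings to %-style arguments matters on hot debug paths: logging builds the message only if a handler actually emits the record, whereas an f-string is formatted on every call even when DEBUG is disabled. It also fixes the original print-style calls like logger.debug(" input_mean =", mean), which the logging module treats as a format string plus a dangling %-argument and reports as a logging error at emit time. A small self-contained comparison:

    import logging

    logging.basicConfig(level=logging.WARNING)
    log = logging.getLogger("demo")

    big = list(range(100_000))

    # Eager: the f-string (and the repr of `big`) is built even though
    # DEBUG records are discarded at WARNING level.
    log.debug(f"big = {big}")

    # Lazy: the message is formatted only if a handler emits the record,
    # so the repr of `big` is never computed here.
    log.debug("big = %s", big)

    # Broken print-style call (the pattern this diff removes): logging
    # treats `big` as a %-format argument, and with no placeholder in the
    # message it prints "--- Logging error ---" at emit time.
    log.warning("big =", big)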


@@ -323,8 +331,9 @@ def scale_outputs(
     output_mean: npt.NDArray,
     output_std: npt.NDArray,
     output_scale_factor_cms: float,
+    precipitation_value: npt.NDArray,
 ):
-    logger.debug(f"model output: {output[0, 0, 0].numpy().tolist()}")
+    LOG.debug(f"model output: {output[0, 0, 0].numpy().tolist()}")
 
     if cfg["target_variables"][0] in ["qobs_mm_per_hour", "QObs(mm/hr)", "QObs(mm/h)"]:
         surface_runoff_mm = output[0, 0, 0].numpy() * output_std + output_mean
@@ -347,6 +356,9 @@
     # (1/1000) * (self.cfg_bmi['area_sqkm'] * 1000*1000) * (1/3600)
     surface_runoff_volume_m3_s = surface_runoff_mm * output_scale_factor_cms
 
+    # Convert precipitation from mm/h to mm/s for output
+    precip_mms = precipitation_value[0] / 3600.0
+
     # TODO: aaraney: consider making this into a class or closure to avoid so
     # many small allocations.
     yield from (
@@ -360,6 +372,11 @@
             unit="m3 s-1",
             value=bmi_array([surface_runoff_volume_m3_s]),
         ),
+        Var(
+            name="precipitation_rate",
+            unit="mm s-1",
+            value=bmi_array([precip_mms]),
+        ),
     )
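A worked instance of the two unit conversions in this hunk, with an illustrative basin area (the variable names mirror the code; the numbers are made up):

    # Runoff: mm of water over the basin per hour -> cubic meters per second.
    # mm -> m is 1/1000, km2 -> m2 is *1_000_000, per hour -> per second is 1/3600.
    area_sqkm = 2.5  # illustrative; the real value comes from cfg_bmi['area_sqkm']
    output_scale_factor_cms = (1 / 1000) * (area_sqkm * 1000 * 1000) * (1 / 3600)

    surface_runoff_mm = 1.2  # illustrative model output in mm/h
    surface_runoff_volume_m3_s = surface_runoff_mm * output_scale_factor_cms
    # 1.2 mm/h over 2.5 km2 -> 1.2 * 0.6944... ~= 0.833 m3/s

    # Precipitation pass-through: mm/h -> mm/s is a plain division by 3600.
    precipitation_mm_h = 3.6
    precip_mms = precipitation_mm_h / 3600.0  # 0.001 mm/s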


@@ -403,16 +420,17 @@ def __init__(self) -> None:
         self.ensemble_members: list[EnsembleMember]
 
     def initialize(self, config_file: str) -> None:
 
+        # configure the Error, Warning, and Trapping System (EWTS) logger
+        configure_logging()
+
+        LOG.info(f"Initializing with {config_file}")
+
         # read and setup main configuration file
         with open(config_file, "r") as fp:
             self.cfg_bmi = yaml.safe_load(fp)
         coerce_config(self.cfg_bmi)
 
-        # TODO: aaraney: config logging levels to python logging levels
-        # setup logging
-        # self.cfg_bmi["verbose"]
-        configure_logging()
 
         # ----------- The output is area normalized, this is needed to un-normalize it
         # mm->m km2 -> m2 hour->s
         output_factor_cms = (
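Moving configure_logging() to the top of initialize() is what lets the LOG.info call above it actually appear: records logged before any handler exists fall back to logging's last-resort handler, which only passes WARNING and above. A self-contained illustration (basicConfig stands in for configure_logging):

    import logging

    log = logging.getLogger("lstm_ewts.demo")

    # No handler configured yet: the last-resort handler drops INFO records.
    log.info("this line is lost")

    logging.basicConfig(level=logging.INFO)  # stand-in for configure_logging()
    log.info("this line is emitted")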
@@ -458,15 +476,15 @@ def update(self) -> None:
     def update_until(self, time: float) -> None:
         if time <= self.get_current_time():
             current_time = self.get_current_time()
-            logger.warning(f"no update performed: {time=} <= {current_time=}")
+            LOG.warning(f"no update performed: {time=} <= {current_time=}")
             return None
 
         n_steps, remainder = divmod(
             time - self.get_current_time(), self.get_time_step()
         )
 
         if remainder != 0:
-            logger.warning(
+            LOG.warning(
                 f"time is not a multiple of the time step size. updating until: {time - remainder=}"
             )
 
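The divmod() split drives both the loop count and the warning in this hunk: the quotient is the number of whole steps to run, and a nonzero remainder means the requested time is not exactly reachable. A worked example with an assumed one-hour step:

    # Assumed values: current model time 0.0 s, time step 3600.0 s.
    n_steps, remainder = divmod(7200.0 - 0.0, 3600.0)
    # -> (2.0, 0.0): run two full steps, landing exactly on 7200.0

    n_steps, remainder = divmod(5400.0 - 0.0, 3600.0)
    # -> (1.0, 1800.0): the warning fires and the model advances only to
    # 5400.0 - 1800.0 = 3600.0, one whole step short of the requested time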