Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 86 additions & 4 deletions lstm/bmi_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
import pickle
import torch
import yaml

Expand Down Expand Up @@ -200,6 +201,16 @@ def update(self, state: Valuer) -> typing.Iterable[Var]:
precipitation_mm_h
)

def serialize(self):
    """Return this member's LSTM recurrent state as pickle-friendly numpy arrays.

    Returns:
        dict: keys ``"h"`` and ``"c"`` mapping to numpy copies of the hidden
        and cell state tensors.

    ``.detach().cpu()`` is applied before ``.numpy()`` because plain
    ``Tensor.numpy()`` raises a RuntimeError for tensors that track
    gradients or live on a non-CPU device; for a plain CPU tensor the
    result is identical.
    """
    return {
        "h": self.h_t.detach().cpu().numpy(),
        "c": self.c_t.detach().cpu().numpy()
    }

def deserialize(self, data: dict):
    """Restore the LSTM hidden ("h") and cell ("c") state from the arrays
    produced by `serialize()`."""
    for attr, key in (("h_t", "h"), ("c_t", "c")):
        setattr(self, attr, torch.from_numpy(data[key]))


def bmi_array(arr: list[float]) -> npt.NDArray:
"""Trivial wrapper function to ensure the expected numpy array datatype is used."""
Expand Down Expand Up @@ -419,6 +430,10 @@ def __init__(self) -> None:
self.cfg_bmi: dict[str, typing.Any]
self.ensemble_members: list[EnsembleMember]

# statically stored seriaized data
self._serialized_size = np.array([0], dtype=np.uint64)
self._serialized = np.array([], dtype=np.uint8)

def initialize(self, config_file: str) -> None:

# configure the Error Warning and Trapping System logger
Expand Down Expand Up @@ -515,16 +530,38 @@ def get_var_grid(self, name: str) -> int:
return 0

def get_var_type(self, name: str) -> str:
    """Return the numpy dtype name for variable `name`.

    The serialization control variables resolve to the dtypes of the
    BMI-level bookkeeping arrays; every other name is resolved through
    `get_value_ptr`.
    """
    if name == "serialization_state":
        dtype = self._serialized.dtype
    elif name in ("serialization_size", "serialization_create"):
        dtype = self._serialized_size.dtype
    elif name == "serialization_free":
        dtype = np.dtype(np.intc)
    elif name == "reset_time":
        dtype = np.dtype(np.double)
    else:
        dtype = self.get_value_ptr(name).dtype
    return dtype.name

def get_var_units(self, name: str) -> str:
    """Look up the unit string for variable `name` in whichever facade
    (outputs or dynamic inputs) contains it."""
    holder = first_containing(name, self._outputs, self._dynamic_inputs)
    return holder.unit(name)

def get_var_itemsize(self, name: str) -> int:
    """Return the per-element byte size for variable `name`, including the
    serialization control variables handled at the BMI level."""
    if name == "serialization_state":
        return self._serialized.dtype.itemsize
    if name in ("serialization_size", "serialization_create"):
        return self._serialized_size.dtype.itemsize
    if name == "reset_time":
        return np.dtype(np.double).itemsize
    if name == "serialization_free":
        return np.dtype(np.intc).itemsize
    # Everything else is a regular model variable backed by a numpy array.
    return self.get_value_ptr(name).itemsize

def get_var_nbytes(self, name: str) -> int:
    """Return the total size in bytes of variable `name`.

    As rendered, this block still contained the pre-change implementation
    line (`return self.get_var_itemsize(name) * ...`) ahead of the new
    dispatch, making the serialization branches unreachable; that stale
    line is removed here.

    `serialization_state` and `serialization_size` need no special case:
    `get_value_ptr` resolves them to the bookkeeping arrays, whose
    `.nbytes` is correct.
    """
    if name == "serialization_create":
        return self._serialized_size.nbytes
    if name == "serialization_free":
        return np.dtype(np.intc).itemsize
    if name == "reset_time":
        return np.dtype(np.double).itemsize
    return self.get_value_ptr(name).nbytes

def get_var_location(self, name: str) -> str:
# raises KeyError on failure
Expand Down Expand Up @@ -553,6 +590,10 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:

def get_value_ptr(self, name: str) -> np.ndarray:
    """Returns a _reference_ to a variable's np.NDArray.

    The serialization bookkeeping arrays are owned by the BMI object
    itself; all other names are resolved through the output /
    dynamic-input facades.
    """
    bmi_owned = {
        "serialization_state": self._serialized,
        "serialization_size": self._serialized_size,
    }
    if name in bmi_owned:
        return bmi_owned[name]
    return first_containing(name, self._outputs, self._dynamic_inputs).value(name)

def get_value_at_indices(
Expand All @@ -563,9 +604,18 @@ def get_value_at_indices(
).value_at_indices(name, dest, inds)

def set_value(self, name: str, src: np.ndarray) -> None:
    """Set the value of variable `name`, or execute a serialization command.

    Four reserved names act as commands rather than plain variables:
    - "serialization_state": interpret `src` as pickled state and restore it
    - "serialization_create": snapshot current state (``src`` is ignored)
    - "serialization_free": discard the stored snapshot (``src`` is ignored)
    - "reset_time": reset the internal timestep counter to 0 (``src`` is ignored)

    As rendered, this block still contained the pre-change body (an
    unconditional `return first_containing(...)`) ahead of the dispatch,
    which made every command branch unreachable; those stale lines are
    removed here. Any other name is forwarded to the owning facade.
    """
    if name == "serialization_state":
        self._deserialize(src)
    elif name == "serialization_create":
        self._serialize()
    elif name == "serialization_free":
        self._free_serialized()
    elif name == "reset_time":
        self._timestep = 0
    else:
        return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
            name, src
        )

def set_value_at_indices(
self, name: str, inds: np.ndarray, src: np.ndarray
Expand Down Expand Up @@ -593,6 +643,38 @@ def get_grid_type(self, grid: int) -> str:
return "scalar"
raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0")

def _serialize(self):
"""Convert all dynamic properties that can change after the `bmi_LSTM` has had `initialize()` called into an object that can be serialized through `pickle`.
Then, set the BMI's `_serialized` property to the byte representation of that pickled data and adjust the static `_serialized_size` property."""
data = {
"dynamic_inputs": self._dynamic_inputs.serialize(),
"static_inputs": self._static_inputs.serialize(),
"outputs": self._outputs.serialize(),
"ensemble": [em.serialize() for em in self.ensemble_members],
"timestep": self._timestep,
}
serialized = pickle.dumps(data)
self._serialized = np.array(bytearray(serialized), dtype=np.uint8)
self._serialized_size[0] = len(self._serialized)

def _deserialize(self, array: np.ndarray):
"""Interpret the bytes of the numpy array as previously pickled data from `_serialize()` and update the current values.
No data structure check will be made on the input array or loaded bytes. It will be assumed that the input data is of the same structure as what is generated from `_serialize()`."""
data = bytes(array)
deserialized = pickle.loads(data)
self._dynamic_inputs.deserialize(deserialized["dynamic_inputs"])
self._static_inputs.deserialize(deserialized["static_inputs"])
self._outputs.deserialize(deserialized["outputs"])
for bmi_em, data_em in zip(self.ensemble_members, deserialized["ensemble"], strict=True):
bmi_em.deserialize(data_em)
self._timestep = deserialized["timestep"]
self._free_serialized()

def _free_serialized(self):
"""Clear the current serialized data and set the size property value to 0."""
self._serialized_size[0] = 0
self._serialized = np.array([], dtype=self._serialized.dtype)


def coerce_config(cfg: dict[str, typing.Any]):
for key, val in cfg.items():
Expand Down
17 changes: 17 additions & 0 deletions lstm/model_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@ def __iter__(self) -> typing.Iterator[Var]:
def __len__(self) -> int:
return len(self._name_mapping)

def serialize(self):
    """Return the State represented as a list of dicts, one per `Var`,
    each carrying the Var's name, unit, and value."""
    snapshot = []
    for var in self._name_mapping.values():
        snapshot.append({"name": var.name, "unit": var.unit, "value": var.value})
    return snapshot

def deserialize(self, values: list):
    """Replace the current Vars with values from the input list.

    Each entry must be a dict with "name", "unit", and "value" keys, as
    produced by `serialize()`.
    """
    self._name_mapping.clear()
    for entry in values:
        self._name_mapping[entry["name"]] = Var(
            name=entry["name"],
            unit=entry["unit"],
            value=entry["value"],
        )


class StateFacade:
"""
Expand Down