diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py
index 85422e5..44b93b7 100644
--- a/lstm/bmi_lstm.py
+++ b/lstm/bmi_lstm.py
@@ -57,6 +57,7 @@
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import pickle
 import torch
 import yaml
 
@@ -200,6 +201,16 @@ def update(self, state: Valuer) -> typing.Iterable[Var]:
             precipitation_mm_h
         )
 
+    def serialize(self):
+        return {
+            "h": self.h_t.numpy(),
+            "c": self.c_t.numpy()
+        }
+
+    def deserialize(self, data: dict):
+        self.h_t = torch.from_numpy(data["h"])
+        self.c_t = torch.from_numpy(data["c"])
+
 
 def bmi_array(arr: list[float]) -> npt.NDArray:
     """Trivial wrapper function to ensure the expected numpy array datatype is used."""
@@ -419,6 +430,10 @@ def __init__(self) -> None:
         self.cfg_bmi: dict[str, typing.Any]
         self.ensemble_members: list[EnsembleMember]
 
+        # statically stored serialized data
+        self._serialized_size = np.array([0], dtype=np.uint64)
+        self._serialized = np.array([], dtype=np.uint8)
+
     def initialize(self, config_file: str) -> None:
         # configure the Error Warning and Trapping System logger
@@ -515,16 +530,38 @@ def get_var_grid(self, name: str) -> int:
         return 0
 
     def get_var_type(self, name: str) -> str:
+        if name == "serialization_state":
+            return self._serialized.dtype.name
+        elif name == "serialization_size" or name == "serialization_create":
+            return self._serialized_size.dtype.name
+        elif name == "serialization_free":
+            return np.dtype(np.intc).name
+        elif name == "reset_time":
+            return np.dtype(np.double).name
         return self.get_value_ptr(name).dtype.name
 
     def get_var_units(self, name: str) -> str:
         return first_containing(name, self._outputs, self._dynamic_inputs).unit(name)
 
    def get_var_itemsize(self, name: str) -> int:
+        if name == "serialization_state":
+            return self._serialized.dtype.itemsize
+        if name == "serialization_size" or name == "serialization_create":
+            return self._serialized_size.dtype.itemsize
+        if name == "serialization_free":
+            return np.dtype(np.intc).itemsize
+        if name == "reset_time":
+            return np.dtype(np.double).itemsize
         return self.get_value_ptr(name).itemsize
 
     def get_var_nbytes(self, name: str) -> int:
-        return self.get_var_itemsize(name) * len(self.get_value_ptr(name))
+        if name == "serialization_create":
+            return self._serialized_size.nbytes
+        if name == "serialization_free":
+            return np.dtype(np.intc).itemsize
+        if name == "reset_time":
+            return np.dtype(np.double).itemsize
+        return self.get_value_ptr(name).nbytes
 
     def get_var_location(self, name: str) -> str:
         # raises KeyError on failure
@@ -553,6 +590,10 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:
 
     def get_value_ptr(self, name: str) -> np.ndarray:
         """Returns a _reference_ to a variable's np.NDArray."""
+        if name == "serialization_state":
+            return self._serialized
+        elif name == "serialization_size":
+            return self._serialized_size
         return first_containing(name, self._outputs, self._dynamic_inputs).value(name)
 
     def get_value_at_indices(
@@ -563,9 +604,18 @@
         ).value_at_indices(name, dest, inds)
 
     def set_value(self, name: str, src: np.ndarray) -> None:
-        return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
-            name, src
-        )
+        if name == "serialization_state":
+            self._deserialize(src)
+        elif name == "serialization_create":
+            self._serialize()
+        elif name == "serialization_free":
+            self._free_serialized()
+        elif name == "reset_time":
+            self._timestep = 0
+        else:
+            return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
+                name, src
+            )
 
     def set_value_at_indices(
         self, name: str, inds: np.ndarray, src: np.ndarray
@@ -593,6 +643,38 @@ def get_grid_type(self, grid: int) -> str:
             return "scalar"
         raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0")
 
+    def _serialize(self):
+        """Gather every dynamic property that can change after `initialize()` has been called on the `bmi_LSTM` into an object that can be serialized with `pickle`.
+        Store the pickled bytes in the `_serialized` property and update the statically allocated `_serialized_size` property to match."""
+        data = {
+            "dynamic_inputs": self._dynamic_inputs.serialize(),
+            "static_inputs": self._static_inputs.serialize(),
+            "outputs": self._outputs.serialize(),
+            "ensemble": [em.serialize() for em in self.ensemble_members],
+            "timestep": self._timestep,
+        }
+        serialized = pickle.dumps(data)
+        self._serialized = np.array(bytearray(serialized), dtype=np.uint8)
+        self._serialized_size[0] = len(self._serialized)
+
+    def _deserialize(self, array: np.ndarray):
+        """Interpret the bytes of the numpy array as data previously pickled by `_serialize()` and update the current values.
+        No structural checks are made on the input array or the loaded bytes; the data is assumed to match the structure produced by `_serialize()`."""
+        data = bytes(array)
+        deserialized = pickle.loads(data)
+        self._dynamic_inputs.deserialize(deserialized["dynamic_inputs"])
+        self._static_inputs.deserialize(deserialized["static_inputs"])
+        self._outputs.deserialize(deserialized["outputs"])
+        for bmi_em, data_em in zip(self.ensemble_members, deserialized["ensemble"], strict=True):
+            bmi_em.deserialize(data_em)
+        self._timestep = deserialized["timestep"]
+        self._free_serialized()
+
+    def _free_serialized(self):
+        """Clear the current serialized data and reset the size property to 0."""
+        self._serialized_size[0] = 0
+        self._serialized = np.array([], dtype=self._serialized.dtype)
+
 
 def coerce_config(cfg: dict[str, typing.Any]):
     for key, val in cfg.items():
diff --git a/lstm/model_state.py b/lstm/model_state.py
index c13fcfc..bbacbd6 100644
--- a/lstm/model_state.py
+++ b/lstm/model_state.py
@@ -94,6 +94,23 @@ def __iter__(self) -> typing.Iterator[Var]:
     def __len__(self) -> int:
         return len(self._name_mapping)
 
+    def serialize(self):
+        """Return the State as a list of dicts, one per `Var`, holding each `Var`'s properties."""
+        return [
+            {"name": var.name, "unit": var.unit, "value": var.value}
+            for var in self._name_mapping.values()
+        ]
+
+    def deserialize(self, values: list):
+        """Replace the current Vars with the values from the input list."""
+        self._name_mapping.clear()
+        for var in values:
+            self._name_mapping[var["name"]] = Var(
+                name=var["name"],
+                unit=var["unit"],
+                value=var["value"]
+            )
+
 
 class StateFacade:
     """
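Below is a minimal sketch of how a caller might drive the new serialization control variables end to end through the standard BMI getters and setters. It is illustrative only: the class name `bmi_LSTM`, the module path `lstm.bmi_lstm`, the `config.yml` path, the standard BMI `update()` call, and the assumption that `get_value()` copies from `get_value_ptr()` are not established by this patch.

import numpy as np

from lstm.bmi_lstm import bmi_LSTM  # assumed module/class name

model = bmi_LSTM()
model.initialize("config.yml")  # hypothetical config path

# Advance a few timesteps, then snapshot the model.
for _ in range(10):
    model.update()

# Writing to "serialization_create" triggers _serialize(); the pickled payload
# is then readable through "serialization_size" and "serialization_state".
model.set_value("serialization_create", np.zeros(1, dtype=np.uint64))

size = np.zeros(1, dtype=np.uint64)
model.get_value("serialization_size", size)

state = np.zeros(int(size[0]), dtype=np.uint8)
model.get_value("serialization_state", state)

# Drop the internally held copy once the bytes have been read out.
model.set_value("serialization_free", np.zeros(1, dtype=np.intc))

# ...keep running, then roll the model back to the snapshot.
for _ in range(5):
    model.update()

model.set_value("serialization_state", state)  # routes to _deserialize()

Because `_deserialize()` unpickles whatever bytes it is handed, the state buffer should only come from a matching `_serialize()` call in a trusted process; exposing the snapshot as a plain uint8 buffer with a uint64 size appears intended to let non-Python BMI hosts move it across the interface. Setting "reset_time" separately zeroes the internal timestep counter without restoring any other state.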