From 0691d769ed79ebdf0f5442e49996ce820d949aae Mon Sep 17 00:00:00 2001
From: Ian Todd
Date: Tue, 27 Jan 2026 08:54:03 -0500
Subject: [PATCH 1/3] Implement serialization

---
 lstm/bmi_lstm.py    | 88 ++++++++++++++++++++++++++++++++++++++++++---
 lstm/model_state.py | 17 +++++++++
 2 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py
index 85422e5..7e0aaf1 100644
--- a/lstm/bmi_lstm.py
+++ b/lstm/bmi_lstm.py
@@ -57,6 +57,7 @@
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+import pickle
 import torch
 import yaml

@@ -200,6 +201,16 @@ def update(self, state: Valuer) -> typing.Iterable[Var]:
             precipitation_mm_h
         )

+    def serialize(self):
+        return {
+            "h": self.h_t.numpy(),
+            "c": self.c_t.numpy()
+        }
+
+    def deserialize(self, data: dict):
+        self.h_t = torch.from_numpy(data["h"])
+        self.c_t = torch.from_numpy(data["c"])
+

 def bmi_array(arr: list[float]) -> npt.NDArray:
     """Trivial wrapper function to ensure the expected numpy array datatype is used."""
@@ -419,6 +430,10 @@ def __init__(self) -> None:
         self.cfg_bmi: dict[str, typing.Any]
         self.ensemble_members: list[EnsembleMember]

+        # statically stored serialized data
+        self._serialized_size = np.array([0], dtype=np.uint64)
+        self._serialized = np.array([], dtype=np.uint8)
+
     def initialize(self, config_file: str) -> None:
         # configure the Error Warning and Trapping System logger
@@ -515,16 +530,38 @@ def get_var_grid(self, name: str) -> int:
         return 0

     def get_var_type(self, name: str) -> str:
+        if name == "serialization_state":
+            return self._serialized.dtype.name
+        elif name == "serialization_size" or name == "serialization_create":
+            return self._serialized_size.dtype.name
+        elif name == "serialization_free":
+            return np.dtype(np.intc).name
+        elif name == "reset_time":
+            return np.dtype(np.double).name
         return self.get_value_ptr(name).dtype.name

     def get_var_units(self, name: str) -> str:
         return first_containing(name, self._outputs, self._dynamic_inputs).unit(name)

     def get_var_itemsize(self, name: str) -> int:
+        if name == "serialization_state":
+            return self._serialized.dtype.itemsize
+        if name == "serialization_size" or name == "serialization_create":
+            return self._serialized_size.dtype.itemsize
+        if name == "serialization_free":
+            return np.dtype(np.intc).itemsize
+        if name == "reset_time":
+            return np.dtype(np.double).itemsize
         return self.get_value_ptr(name).itemsize

     def get_var_nbytes(self, name: str) -> int:
-        return self.get_var_itemsize(name) * len(self.get_value_ptr(name))
+        if name == "serialization_create":
+            return self._serialized_size.nbytes
+        if name == "serialization_free":
+            return np.dtype(np.intc).itemsize
+        if name == "reset_time":
+            return np.dtype(np.double).itemsize
+        return self.get_value_ptr(name).nbytes

     def get_var_location(self, name: str) -> str:
         # raises KeyError on failure
@@ -553,6 +590,10 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:

     def get_value_ptr(self, name: str) -> np.ndarray:
         """Returns a _reference_ to a variable's np.NDArray."""
+        if name == "serialization_state":
+            return self._serialized
+        elif name == "serialization_size":
+            return self._serialized_size
         return first_containing(name, self._outputs, self._dynamic_inputs).value(name)

     def get_value_at_indices(
         self, name: str, dest: np.ndarray, inds: np.ndarray
     ) -> np.ndarray:
         return first_containing(
             name, self._outputs, self._dynamic_inputs
         ).value_at_indices(name, dest, inds)

     def set_value(self, name: str, src: np.ndarray) -> None:
-        return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
-            name, src
-        )
+        if name == "serialization_state":
+            deserialize_bmi(src, self)
+        elif name == "serialization_create":
+            serialize_bmi(self)
+        elif name == "serialization_free":
+            free_serialized_bmi(self)
+        elif name == "reset_time":
+            self._timestep = 0
+        else:
+            return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
+                name, src
+            )

     def set_value_at_indices(
         self, name: str, inds: np.ndarray, src: np.ndarray
@@ -594,6 +644,36 @@ def get_grid_type(self, grid: int) -> str:
         raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0")


+def serialize_bmi(bmi: bmi_LSTM):
+    data = {
+        "dynamic_inputs": bmi._dynamic_inputs.serialize(),
+        "static_inputs": bmi._static_inputs.serialize(),
+        "outputs": bmi._outputs.serialize(),
+        "ensemble": [em.serialize() for em in bmi.ensemble_members],
+        "timestep": bmi._timestep,
+    }
+    serialized = pickle.dumps(data)
+    bmi._serialized = np.array(bytearray(serialized), dtype=np.uint8)
+    bmi._serialized_size[0] = len(bmi._serialized)
+
+
+def deserialize_bmi(array: np.ndarray, bmi: bmi_LSTM):
+    data = bytes(array)
+    deserialized = pickle.loads(data)
+    bmi._dynamic_inputs.deserialize(deserialized["dynamic_inputs"])
+    bmi._static_inputs.deserialize(deserialized["static_inputs"])
+    bmi._outputs.deserialize(deserialized["outputs"])
+    for bmi_em, data_em in zip(bmi.ensemble_members, deserialized["ensemble"], strict=True):
+        bmi_em.deserialize(data_em)
+    bmi._timestep = deserialized["timestep"]
+    free_serialized_bmi(bmi)
+
+
+def free_serialized_bmi(bmi: bmi_LSTM):
+    bmi._serialized_size[0] = 0
+    bmi._serialized = np.array([], dtype=bmi._serialized.dtype)
+
+
 def coerce_config(cfg: dict[str, typing.Any]):
     for key, val in cfg.items():
         # Handle 'train_cfg_file' specifically to ensure it is always a list
diff --git a/lstm/model_state.py b/lstm/model_state.py
index c13fcfc..575d46d 100644
--- a/lstm/model_state.py
+++ b/lstm/model_state.py
@@ -94,6 +94,23 @@ def __iter__(self) -> typing.Iterator[Var]:
     def __len__(self) -> int:
         return len(self._name_mapping)

+    def serialize(self):
+        """Return the State represented as a list."""
+        return [
+            {"name": var.name, "unit": var.unit, "value": var.value}
+            for var in self._name_mapping.values()
+        ]
+
+    def deserialize(self, values: list):
+        """Replace the current Vars with values from the input list."""
+        self._name_mapping.clear()
+        for var in values:
+            self._name_mapping[var["name"]] = Var(
+                name=var["name"],
+                unit=var["unit"],
+                value=var["value"]
+            )
+

 class StateFacade:
     """

From 72c2ab73d80652aac18a3d35b2b5af212eb88122 Mon Sep 17 00:00:00 2001
From: Ian Todd
Date: Mon, 2 Feb 2026 11:15:34 -0500
Subject: [PATCH 2/3] Add docstrings

---
 lstm/bmi_lstm.py    | 66 +++++++++++++++++++++++++----------------------
 lstm/model_state.py |  2 +-
 2 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py
index 7e0aaf1..1653f16 100644
--- a/lstm/bmi_lstm.py
+++ b/lstm/bmi_lstm.py
@@ -605,11 +605,11 @@ def get_value_at_indices(

     def set_value(self, name: str, src: np.ndarray) -> None:
         if name == "serialization_state":
-            deserialize_bmi(src, self)
+            self._deserialize(src, self)
         elif name == "serialization_create":
-            serialize_bmi(self)
+            self._serialize(self)
         elif name == "serialization_free":
-            free_serialized_bmi(self)
+            self._free_serialized(self)
         elif name == "reset_time":
             self._timestep = 0
         else:
@@ -643,35 +643,37 @@ def get_grid_type(self, grid: int) -> str:
             return "scalar"
         raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0")

-
-def serialize_bmi(bmi: bmi_LSTM):
-    data = {
-        "dynamic_inputs": bmi._dynamic_inputs.serialize(),
-        "static_inputs": bmi._static_inputs.serialize(),
-        "outputs": bmi._outputs.serialize(),
-        "ensemble": [em.serialize() for em in bmi.ensemble_members],
-        "timestep": bmi._timestep,
-    }
-    serialized = pickle.dumps(data)
-    bmi._serialized = np.array(bytearray(serialized), dtype=np.uint8)
-    bmi._serialized_size[0] = len(bmi._serialized)
-
-
-def deserialize_bmi(array: np.ndarray, bmi: bmi_LSTM):
-    data = bytes(array)
-    deserialized = pickle.loads(data)
-    bmi._dynamic_inputs.deserialize(deserialized["dynamic_inputs"])
-    bmi._static_inputs.deserialize(deserialized["static_inputs"])
-    bmi._outputs.deserialize(deserialized["outputs"])
-    for bmi_em, data_em in zip(bmi.ensemble_members, deserialized["ensemble"], strict=True):
-        bmi_em.deserialize(data_em)
-    bmi._timestep = deserialized["timestep"]
-    free_serialized_bmi(bmi)
-
-
-def free_serialized_bmi(bmi: bmi_LSTM):
-    bmi._serialized_size[0] = 0
-    bmi._serialized = np.array([], dtype=bmi._serialized.dtype)
+    def _serialize(self):
+        """Convert all dynamic properties that can change after the `bmi_LSTM` has had `initialize()` called into an object that can be serialized through `pickle`.
+        Then, set the BMI's `_serialized` property to the byte representation of that pickled data and adjust the static `_serialized_size` property."""
+        data = {
+            "dynamic_inputs": self._dynamic_inputs.serialize(),
+            "static_inputs": self._static_inputs.serialize(),
+            "outputs": self._outputs.serialize(),
+            "ensemble": [em.serialize() for em in self.ensemble_members],
+            "timestep": self._timestep,
+        }
+        serialized = pickle.dumps(data)
+        self._serialized = np.array(bytearray(serialized), dtype=np.uint8)
+        self._serialized_size[0] = len(self._serialized)
+
+    def _deserialize(self, array: np.ndarray):
+        """Interpret the bytes of the numpy array as previously pickled data from `_serialize()` and update the current values.
+        No data structure check will be made on the input array or loaded bytes. It will be assumed that the input data is of the same structure as what is generated from `_serialize()`."""
+        data = bytes(array)
+        deserialized = pickle.loads(data)
+        self._dynamic_inputs.deserialize(deserialized["dynamic_inputs"])
+        self._static_inputs.deserialize(deserialized["static_inputs"])
+        self._outputs.deserialize(deserialized["outputs"])
+        for bmi_em, data_em in zip(self.ensemble_members, deserialized["ensemble"], strict=True):
+            bmi_em.deserialize(data_em)
+        self._timestep = deserialized["timestep"]
+        self._free_serialized()
+
+    def _free_serialized(self):
+        """Clear the current serialized data and set the size property value to 0."""
+        self._serialized_size[0] = 0
+        self._serialized = np.array([], dtype=self._serialized.dtype)


 def coerce_config(cfg: dict[str, typing.Any]):
diff --git a/lstm/model_state.py b/lstm/model_state.py
index 575d46d..bbacbd6 100644
--- a/lstm/model_state.py
+++ b/lstm/model_state.py
@@ -95,7 +95,7 @@ def __len__(self) -> int:
         return len(self._name_mapping)

     def serialize(self):
-        """Return the State represented as a list."""
+        """Return the State represented as a list of dicts representing the `Var` properties."""
         return [
             {"name": var.name, "unit": var.unit, "value": var.value}
             for var in self._name_mapping.values()

From 650d5ed6a3986c8fda103a121f6da724cb7220a7 Mon Sep 17 00:00:00 2001
From: Ian Todd
Date: Wed, 18 Feb 2026 15:07:00 -0500
Subject: [PATCH 3/3] Remove self arguments from serialization methods

---
 lstm/bmi_lstm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py
index 1653f16..44b93b7 100644
--- a/lstm/bmi_lstm.py
+++ b/lstm/bmi_lstm.py
@@ -605,11 +605,11 @@ def get_value_at_indices(

     def set_value(self, name: str, src: np.ndarray) -> None:
         if name == "serialization_state":
-            self._deserialize(src, self)
+            self._deserialize(src)
         elif name == "serialization_create":
-            self._serialize(self)
+            self._serialize()
         elif name == "serialization_free":
-            self._free_serialized(self)
+            self._free_serialized()
         elif name == "reset_time":
             self._timestep = 0
         else:
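
Usage note: the sketch below shows one way a caller could drive the new serialization control variables end to end. It is illustrative only, not part of the patches: the import path, the config file name, and the update() call are assumed, and it presumes get_value() copies from the same buffers that get_value_ptr() exposes.

    import numpy as np
    from lstm.bmi_lstm import bmi_LSTM  # assumed import path

    src = bmi_LSTM()
    src.initialize("bmi_config.yml")  # hypothetical config file
    src.update()

    # "serialization_create" triggers pickling of the mutable state into the
    # static buffer; the value passed to set_value() is ignored.
    src.set_value("serialization_create", np.array([1], dtype=np.uint64))

    # Read how many bytes were produced, then copy the pickled bytes out.
    size = np.zeros(1, dtype=np.uint64)
    src.get_value("serialization_size", size)
    state = np.zeros(int(size[0]), dtype=np.uint8)
    src.get_value("serialization_state", state)

    # "serialization_free" clears the stored buffer and resets the size to 0.
    src.set_value("serialization_free", np.array([1], dtype=np.intc))

    # Setting "serialization_state" on another initialized instance unpickles
    # the bytes and overwrites its inputs, outputs, ensemble state, and timestep.
    dst = bmi_LSTM()
    dst.initialize("bmi_config.yml")
    dst.set_value("serialization_state", state)

Setting "reset_time" is independent of this round trip; it simply rewinds the internal timestep to 0.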