From 35e0b3055685756fdbaf68328cd7fe4dde3591f9 Mon Sep 17 00:00:00 2001
From: caic99
Date: Tue, 6 Feb 2024 12:32:28 +0800
Subject: [PATCH] extract time step from input data

---
 deepflame/data.py      | 11 +++++------
 deepflame/inference.py |  4 ++--
 deepflame/trainer.py   |  3 ++-
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/deepflame/data.py b/deepflame/data.py
index 737b466..03c761c 100644
--- a/deepflame/data.py
+++ b/deepflame/data.py
@@ -36,15 +36,14 @@ def __init__(

         # Load Dataset
         self.data = torch.tensor(np.load(data_path))
-        # DIMENSION: n * ((T, P, Y[ns], H[ns])_in, (T, P, Y[ns], H[ns])_out)
-        self.time_step = 1e-7  # TODO: load from dataset
+        # DIMENSION: n * ((T, P, Y[ns], H[ns])_in, (T, P, Y[ns], H[ns])_out, time_step)
         self.formation_enthalpies = torch.tensor(np.load(formation_enthalpies_path))
         assert (
             self.formation_enthalpies.shape[0] == self.n_species
         ), "n_species in dataset does not match formation_enthalpies"
-        # self.dims = (1, 2 * (2 + 2 * self.n_species))
+        # self.dims = (1, 2 * (2 + 2 * self.n_species) + 1)
         assert (
-            self.data.shape[1] == 4 + 4 * self.n_species
+            self.data.shape[1] == 4 + 4 * self.n_species + 1
         ), "n_species in dataset does not match config file"
         self.indices = (
             1,  # T_in
@@ -55,8 +54,9 @@ def __init__(
             1,  # P_out
             self.n_species,  # Y_label[ns]
             self.n_species,  # H_out[ns]
+            1,  # dt
         )
-        (T_in, P_in, Y_in, H_in, T_label, P_label, Y_label, H_label) = self.data.split(
+        (T_in, P_in, Y_in, H_in, T_label, P_label, Y_label, H_label, time_step) = self.data.split(
             self.indices, dim=1
         )
         # The mass fraction calculated by cantera is not guaranteed to be positive.
@@ -83,7 +83,6 @@ def set_norm_stats(self, key: str, value=None):

         set_stats(self, "lmbda", lmbda)
         set_stats(self, "formation_enthalpies")
-        set_stats(self, "time_step")
         set_norm_stats(self, "T_in", T_in)
         set_norm_stats(self, "P_in", P_in)
         self.Y_t_in = boxcox(Y_in, lmbda)
diff --git a/deepflame/inference.py b/deepflame/inference.py
index b5064a0..7723fe2 100644
--- a/deepflame/inference.py
+++ b/deepflame/inference.py
@@ -103,9 +103,9 @@ def load_lightning_model(
 # Currently `inference()` is called directly from C++,
 # so we have to explicitly put the model in the scope of this file.
 # TODO: fix this
-module: torch.nn.Module = load_lightning_model().eval()
+module: torch.nn.Module = load_lightning_model().eval()  # TODO: load model checkpoint from config
 n_species: int = module.model.formation_enthalpies.shape[0]
-time_step: float = module.model.time_step
+time_step: float = settings["inferenceDeltaTime"]
 lmbda: float = module.model.lmbda


diff --git a/deepflame/trainer.py b/deepflame/trainer.py
index 55a7b3c..f3a5bfc 100644
--- a/deepflame/trainer.py
+++ b/deepflame/trainer.py
@@ -39,6 +39,7 @@ def training_step(self, batch, batch_idx):
             P_label,
             Y_label,
             H_label,
+            time_step,
         ) = batch
         Y_t_in = boxcox(Y_in, self.model.lmbda)
         Y_t_label = boxcox(Y_label, self.model.lmbda)
@@ -60,7 +61,7 @@ def training_step(self, batch, batch_idx):
         )

         scale = (
-            self.model.time_step * 1e13
+            time_step * 1e13
         )  # prevent overflow introduced by large H and small time_step
         loss3 = (
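
Note (illustration, not part of the applied diff): after this change the dataset
array loaded in deepflame/data.py must carry one extra trailing column holding the
per-sample time step, i.e. 4 + 4 * n_species + 1 columns ordered as
(T, P, Y[ns], H[ns])_in, (T, P, Y[ns], H[ns])_out, time_step. Below is a minimal
sketch of assembling such an array; the values are synthetic, and n_species, the
file name, and the 1e-7 time step are illustrative assumptions only:

    import numpy as np

    n_species = 5   # hypothetical mechanism size
    n_samples = 10
    rng = np.random.default_rng(0)

    state_in = rng.random((n_samples, 2 + 2 * n_species))   # (T, P, Y[ns], H[ns])_in
    state_out = rng.random((n_samples, 2 + 2 * n_species))  # (T, P, Y[ns], H[ns])_out
    time_step = np.full((n_samples, 1), 1e-7)               # per-sample dt column

    data = np.hstack([state_in, state_out, time_step])
    assert data.shape[1] == 4 + 4 * n_species + 1  # matches the new assert in data.py
    np.save("dataset.npy", data)                   # later read via np.load(data_path)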