Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions deepflame/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@ def __init__(

# Load Dataset
self.data = torch.tensor(np.load(data_path))
# DIMENSION: n * ((T, P, Y[ns], H[ns])_in, (T, P, Y[ns], H[ns])_out)
self.time_step = 1e-7 # TODO: load from dataset
# DIMENSION: n * ((T, P, Y[ns], H[ns])_in, (T, P, Y[ns], H[ns])_out, time_step)
self.formation_enthalpies = torch.tensor(np.load(formation_enthalpies_path))
assert (
self.formation_enthalpies.shape[0] == self.n_species
), "n_species in dataset does not match formation_enthalpies"
# self.dims = (1, 2 * (2 + 2 * self.n_species))
# self.dims = (1, 2 * (2 + 2 * self.n_species) + 1)
assert (
self.data.shape[1] == 4 + 4 * self.n_species
self.data.shape[1] == 4 + 4 * self.n_species + 1
), "n_species in dataset does not match config file"
self.indices = (
1, # T_in
Expand All @@ -55,8 +54,9 @@ def __init__(
1, # P_out
self.n_species, # Y_label[ns]
self.n_species, # H_out[ns]
1, # dt
)
(T_in, P_in, Y_in, H_in, T_label, P_label, Y_label, H_label) = self.data.split(
(T_in, P_in, Y_in, H_in, T_label, P_label, Y_label, H_label, time_step) = self.data.split(
self.indices, dim=1
)
# The mass fraction calculated by cantera is not guaranteed to be positive.
Expand All @@ -83,7 +83,6 @@ def set_norm_stats(self, key: str, value=None):

set_stats(self, "lmbda", lmbda)
set_stats(self, "formation_enthalpies")
set_stats(self, "time_step")
set_norm_stats(self, "T_in", T_in)
set_norm_stats(self, "P_in", P_in)
self.Y_t_in = boxcox(Y_in, lmbda)
Expand Down
4 changes: 2 additions & 2 deletions deepflame/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def load_lightning_model(
# Currently `inference()` is called directly from C++,
# so we have to explicitly put the model in the scope of this file.
# TODO: fix this
module: torch.nn.Module = load_lightning_model().eval()
module: torch.nn.Module = load_lightning_model().eval() # TODO: load model checkpoint from config
n_species: int = module.model.formation_enthalpies.shape[0]
time_step: float = module.model.time_step
time_step: float = settings["inferenceDeltaTime"]
lmbda: float = module.model.lmbda


Expand Down
3 changes: 2 additions & 1 deletion deepflame/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def training_step(self, batch, batch_idx):
P_label,
Y_label,
H_label,
time_step,
) = batch
Y_t_in = boxcox(Y_in, self.model.lmbda)
Y_t_label = boxcox(Y_label, self.model.lmbda)
Expand All @@ -60,7 +61,7 @@ def training_step(self, batch, batch_idx):
)

scale = (
self.model.time_step * 1e13
time_step * 1e13
) # prevent overflow introduced by large H and small time_step

loss3 = (
Expand Down