From 2885ec29ac41d685ffd2f9c146a8b18739a6fedf Mon Sep 17 00:00:00 2001
From: Benjamin Bennett <125957+benbennett@users.noreply.github.com>
Date: Tue, 23 Sep 2025 23:54:43 -0500
Subject: [PATCH] Detach loss tensors before scalar conversion

---
 models/base_model.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/models/base_model.py b/models/base_model.py
index d84ded8..a740b7d 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -123,11 +123,16 @@ def get_current_visuals(self):
 
     def get_current_losses(self):
         """Return traning losses / errors. simple.py will print out these errors on console, and save them to a file"""
-        errors_ret = OrderedDict()
-        for name in self.loss_names:
-            if isinstance(name, str):
-                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
-        return errors_ret
+        errors_ret = OrderedDict()
+        for name in self.loss_names:
+            if isinstance(name, str):
+                loss_value = getattr(self, 'loss_' + name)
+                if isinstance(loss_value, torch.Tensor):
+                    # detach() avoids inadvertently keeping autograd history when converting tensors to scalars
+                    errors_ret[name] = loss_value.detach().item()
+                else:
+                    errors_ret[name] = float(loss_value)
+        return errors_ret
 
     def save_networks(self, epoch):
         """Save all the networks to the disk.
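
For reference, a minimal standalone sketch of the conversion logic this patch introduces. `losses_to_scalars` is a hypothetical helper used here only for illustration (the real code lives in `BaseModel.get_current_losses` and reads attributes named `loss_<name>`):

```python
import torch
from collections import OrderedDict

def losses_to_scalars(loss_dict):
    """Convert a dict of loss tensors / floats to plain Python floats.

    Hypothetical standalone version of the patched branch logic.
    """
    errors_ret = OrderedDict()
    for name, loss_value in loss_dict.items():
        if isinstance(loss_value, torch.Tensor):
            # detach() severs the value from the autograd graph before
            # .item() extracts it as a Python scalar
            errors_ret[name] = loss_value.detach().item()
        else:
            errors_ret[name] = float(loss_value)
    return errors_ret

# Usage: a loss tensor that still carries autograd history
w = torch.tensor(2.0, requires_grad=True)
loss = (w * 3.0) ** 2  # requires_grad=True, attached to the graph
print(losses_to_scalars({"G": loss, "D": 0.5}))  # OrderedDict([('G', 36.0), ('D', 0.5)])
```

The explicit `isinstance(loss_value, torch.Tensor)` branch preserves the old behavior for plain floats while making the tensor path obvious at the call site, rather than relying on `float(...)` to handle both cases implicitly.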