From f70ac5da72e5ebf06dd7679371b63c68f1d87615 Mon Sep 17 00:00:00 2001 From: "a.shagitov" Date: Mon, 20 Jan 2025 13:35:55 +0300 Subject: [PATCH] Changed the way the optimizer is saved: only its state_dict is pickled --- batchflow/models/torch/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py index 57c39bcdf..d3eb736a9 100755 --- a/batchflow/models/torch/base.py +++ b/batchflow/models/torch/base.py @@ -1747,8 +1747,10 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op path, pickle_module=pickle_module, **kwargs) else: - torch.save({item: getattr(self, item) for item in self.PRESERVE}, - path, pickle_module=pickle_module, **kwargs) + attributes = {item: getattr(self, item) for item in self.PRESERVE if item != "optimizer"} + optimizer = getattr(self, "optimizer") + attributes["optimizer"] = optimizer.state_dict() + torch.save(attributes, path, pickle_module=pickle_module, **kwargs) def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs): """ Load a torch model from a file.