diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py index 57c39bcdf..d3eb736a9 100755 --- a/batchflow/models/torch/base.py +++ b/batchflow/models/torch/base.py @@ -1747,8 +1747,10 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op path, pickle_module=pickle_module, **kwargs) else: - torch.save({item: getattr(self, item) for item in self.PRESERVE}, - path, pickle_module=pickle_module, **kwargs) + attributes = {item: getattr(self, item) for item in self.PRESERVE if item != "optimizer"} + optimizer = getattr(self, "optimizer") + attributes["optimizer"] = optimizer.state_dict() + torch.save(attributes, path, pickle_module=pickle_module, **kwargs) def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs): """ Load a torch model from a file.