diff --git a/batchflow/models/torch/base.py b/batchflow/models/torch/base.py index 57c39bcdf..0e7920e69 100755 --- a/batchflow/models/torch/base.py +++ b/batchflow/models/torch/base.py @@ -389,7 +389,7 @@ def callable_init(module): # example of a callable for init 'microbatch_size': 16, # size of microbatches at training } """ - PRESERVE = [ + PRESERVE = set([ 'full_config', 'config', 'model', 'inputs_shapes', 'targets_shapes', 'classes', 'loss', 'optimizer', 'scaler', 'decay', 'decay_step', @@ -397,7 +397,10 @@ def callable_init(module): # example of a callable for init 'iteration', 'last_train_info', 'last_predict_info', 'lr_list', 'syncs', 'decay_iters', '_loss_list', 'loss_list', 'operations' - ] + ]) + + PRESERVE_ONNX = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay']) + PRESERVE_OPENVINO = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay']) def __init__(self, config=None): if not isinstance(config, (dict, Config)): @@ -1668,7 +1671,7 @@ def convert_outputs(self, outputs): # Store model def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None, - batch_size=None, opset_version=13, pickle_module=dill, **kwargs): + batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs): """ Save underlying PyTorch model along with meta parameters (config, device spec, etc). If `use_onnx` is set to True, then the model is converted to ONNX format and stored in a separate file. @@ -1699,6 +1702,8 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op Version of export standard to use. pickle_module : module Module to use for pickling. + ignore_attributes : str or iterable, optional + List of attributes to ignore when pickling (e.g. 'optimizer') kwargs : dict Other keyword arguments, passed directly to :func:`torch.save`. """ @@ -1710,6 +1715,12 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): self.model = self.model.module + if isinstance(ignore_attributes, str): + ignore_attributes = [ignore_attributes] + elif ignore_attributes is None: + ignore_attributes = [] + ignore_attributes = set(ignore_attributes) + if use_onnx: if batch_size is None: raise ValueError('Specify valid `batch_size`, used for model inference!') @@ -1719,7 +1730,8 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version) # Save the rest of parameters - preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay']) + preserved = self.PRESERVE_ONNX - ignore_attributes + preserved_dict = {item: getattr(self, item) for item in preserved} torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict}, path, pickle_module=pickle_module, **kwargs) @@ -1741,13 +1753,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op ov.save_model(model, output_model=path_openvino) # Save the rest of parameters - preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay']) + preserved = self.PRESERVE_OPENVINO - ignore_attributes preserved_dict = {item: getattr(self, item) for item in preserved} torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict}, path, pickle_module=pickle_module, **kwargs) else: - torch.save({item: getattr(self, item) for item in self.PRESERVE}, + preserved = set(self.PRESERVE) - set(ignore_attributes) + torch.save({item: getattr(self, item) for item in preserved}, path, pickle_module=pickle_module, **kwargs) def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):