Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions batchflow/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,18 @@ def callable_init(module): # example of a callable for init
'microbatch_size': 16, # size of microbatches at training
}
"""
PRESERVE = [
PRESERVE = set([
'full_config', 'config', 'model',
'inputs_shapes', 'targets_shapes', 'classes',
'loss', 'optimizer', 'scaler', 'decay', 'decay_step',
'sync_counter', 'microbatch_size',
'iteration', 'last_train_info', 'last_predict_info',
'lr_list', 'syncs', 'decay_iters',
'_loss_list', 'loss_list', 'operations'
]
])

PRESERVE_ONNX = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
PRESERVE_OPENVINO = PRESERVE - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])

def __init__(self, config=None):
if not isinstance(config, (dict, Config)):
Expand Down Expand Up @@ -1668,7 +1671,7 @@ def convert_outputs(self, outputs):

# Store model
def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_openvino=None,
batch_size=None, opset_version=13, pickle_module=dill, **kwargs):
batch_size=None, opset_version=13, pickle_module=dill, ignore_attributes=None, **kwargs):
""" Save underlying PyTorch model along with meta parameters (config, device spec, etc).

If `use_onnx` is set to True, then the model is converted to ONNX format and stored in a separate file.
Expand Down Expand Up @@ -1699,6 +1702,8 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
Version of export standard to use.
pickle_module : module
Module to use for pickling.
ignore_attributes : str or iterable, optional
    Name or names of attributes to exclude from pickling (e.g. ``'optimizer'``).
kwargs : dict
Other keyword arguments, passed directly to :func:`torch.save`.
"""
Expand All @@ -1710,6 +1715,12 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
self.model = self.model.module

if isinstance(ignore_attributes, str):
ignore_attributes = [ignore_attributes]
elif ignore_attributes is None:
ignore_attributes = []
ignore_attributes = set(ignore_attributes)

if use_onnx:
if batch_size is None:
raise ValueError('Specify valid `batch_size`, used for model inference!')
Expand All @@ -1719,7 +1730,8 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
torch.onnx.export(self.model.eval(), inputs, path_onnx, opset_version=opset_version)

# Save the rest of parameters
preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
preserved = self.PRESERVE_ONNX - ignore_attributes

preserved_dict = {item: getattr(self, item) for item in preserved}
torch.save({'onnx': True, 'path_onnx': path_onnx, 'onnx_batch_size': batch_size, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)
Expand All @@ -1741,13 +1753,14 @@ def save(self, path, use_onnx=False, path_onnx=None, use_openvino=False, path_op
ov.save_model(model, output_model=path_openvino)

# Save the rest of parameters
preserved = set(self.PRESERVE) - set(['model', 'loss', 'optimizer', 'scaler', 'decay'])
preserved = self.PRESERVE_OPENVINO - ignore_attributes
preserved_dict = {item: getattr(self, item) for item in preserved}
torch.save({'openvino': True, 'path_openvino': path_openvino, **preserved_dict},
path, pickle_module=pickle_module, **kwargs)

else:
torch.save({item: getattr(self, item) for item in self.PRESERVE},
preserved = set(self.PRESERVE) - set(ignore_attributes)
torch.save({item: getattr(self, item) for item in preserved},
path, pickle_module=pickle_module, **kwargs)

def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill, **kwargs):
Expand Down
Loading