diff --git a/classification/detect_from_video.py b/classification/detect_from_video.py index 19c31ee..4bd705e 100644 --- a/classification/detect_from_video.py +++ b/classification/detect_from_video.py @@ -132,12 +132,13 @@ def test_full_image_network(video_path, model_path, output_path, face_detector = dlib.get_frontal_face_detector() # Load model - model, *_ = model_selection(modelname='xception', num_out_classes=2) + pretrained = (model_path is None) + model, *_ = model_selection(modelname='xception', num_out_classes=2, pretrained=pretrained) if model_path is not None: model = torch.load(model_path) print('Model found in {}'.format(model_path)) else: - print('No model found, initializing random model.') + print('No model found, using pretrained model.') if cuda: model = model.cuda() @@ -221,7 +222,7 @@ def test_full_image_network(video_path, model_path, output_path, p = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) p.add_argument('--video_path', '-i', type=str) - p.add_argument('--model_path', '-mi', type=str, default=None) + p.add_argument('--model_path', '-m', type=str, default=None) p.add_argument('--output_path', '-o', type=str, default='.') p.add_argument('--start_frame', type=int, default=0) diff --git a/classification/network/models.py b/classification/network/models.py index 06b010e..823c76b 100644 --- a/classification/network/models.py +++ b/classification/network/models.py @@ -38,11 +38,11 @@ class TransferModel(nn.Module): Simple transfer learning model that takes an imagenet pretrained model with a fc layer as base model and retrains a new fc layer for num_out_classes """ - def __init__(self, modelchoice, num_out_classes=2, dropout=0.0): + def __init__(self, modelchoice, num_out_classes=2, dropout=0.0, pretrained=True): super(TransferModel, self).__init__() self.modelchoice = modelchoice if modelchoice == 'xception': - self.model = return_pytorch04_xception() + self.model = return_pytorch04_xception(pretrained) # Replace fc num_ftrs = self.model.last_linear.in_features if not dropout: @@ -55,9 +55,9 @@ def __init__(self, modelchoice, num_out_classes=2, dropout=0.0): ) elif modelchoice == 'resnet50' or modelchoice == 'resnet18': if modelchoice == 'resnet50': - self.model = torchvision.models.resnet50(pretrained=True) + self.model = torchvision.models.resnet50(pretrained=pretrained) if modelchoice == 'resnet18': - self.model = torchvision.models.resnet18(pretrained=True) + self.model = torchvision.models.resnet18(pretrained=pretrained) # Replace fc num_ftrs = self.model.fc.in_features if not dropout: @@ -116,18 +116,21 @@ def forward(self, x): def model_selection(modelname, num_out_classes, - dropout=None): + dropout=None, pretrained=True): """ :param modelname: :return: model, image size, pretraining, input_list """ if modelname == 'xception': return TransferModel(modelchoice='xception', - num_out_classes=num_out_classes), 299, \ - True, ['image'], None + num_out_classes=num_out_classes, + pretrained=pretrained), \ + 299, True, ['image'], None elif modelname == 'resnet18': - return TransferModel(modelchoice='resnet18', dropout=dropout, - num_out_classes=num_out_classes), \ + return TransferModel(modelchoice='resnet18', + dropout=dropout, + num_out_classes=num_out_classes, + pretrained=pretrained), \ 224, True, ['image'], None else: raise NotImplementedError(modelname)