From f3d21a81b2907143170e8721279cc43c36ff5acb Mon Sep 17 00:00:00 2001 From: aashraya Date: Wed, 15 Jul 2020 17:17:55 +0530 Subject: [PATCH 1/2] add changes to support CPU --- demo.py | 2 +- networks/models.py | 4 ++-- networks/transforms.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/demo.py b/demo.py index 505382a..1f6ae73 100644 --- a/demo.py +++ b/demo.py @@ -14,7 +14,7 @@ def np_to_torch(x): - return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda() + return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float() def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray: diff --git a/networks/models.py b/networks/models.py index a0ea431..5585ffe 100644 --- a/networks/models.py +++ b/networks/models.py @@ -17,10 +17,10 @@ def build_model(args): model = MattingModule(net_encoder, net_decoder) - model.cuda() + model if(args.weights != 'default'): - sd = torch.load(args.weights) + sd = torch.load(args.weights, map_location=torch.device('cpu')) model.load_state_dict(sd, strict=True) return model diff --git a/networks/transforms.py b/networks/transforms.py index 860c2b8..30403fb 100644 --- a/networks/transforms.py +++ b/networks/transforms.py @@ -50,7 +50,7 @@ def groupnorm_denormalise_image(img, format='nhwc'): for i in range(3): img[:, :, :, i] = img[:, :, :, i] * group_norm_std[i] + group_norm_mean[i] else: - img1 = torch.zeros_like(img).cuda() + img1 = torch.zeros_like(img) for i in range(3): img1[:, i, :, :] = img[:, i, :, :] * group_norm_std[i] + group_norm_mean[i] return img1 From e6620c7e8f9429c633d5ed3d82841a5da36e6e99 Mon Sep 17 00:00:00 2001 From: aashraya Date: Wed, 15 Jul 2020 17:32:24 +0530 Subject: [PATCH 2/2] added flag to detect cuda device --- cuda_device.py | 3 +++ demo.py | 6 +++++- networks/models.py | 6 ++++-- networks/transforms.py | 3 +++ 4 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 cuda_device.py diff --git a/cuda_device.py b/cuda_device.py new file mode 100644 index 0000000..584ad32 --- /dev/null +++ b/cuda_device.py @@ -0,0 +1,3 @@ +import torch + +CUDA_DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu' diff --git a/demo.py b/demo.py index 1f6ae73..d1e2bee 100644 --- a/demo.py +++ b/demo.py @@ -2,6 +2,7 @@ from networks.transforms import trimap_transform, groupnorm_normalise_image from networks.models import build_model from dataloader import PredDataset +from cuda_device import CUDA_DEVICE # System libs import os @@ -14,7 +15,10 @@ def np_to_torch(x): - return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float() + val = torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float() + if CUDA_DEVICE == 'gpu': + val = val.cuda() + return val def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray: diff --git a/networks/models.py b/networks/models.py index 5585ffe..821f339 100644 --- a/networks/models.py +++ b/networks/models.py @@ -3,6 +3,7 @@ import networks.resnet_GN_WS as resnet_GN_WS import networks.layers_WS as L import networks.resnet_bn as resnet_bn +from cuda_device import CUDA_DEVICE def build_model(args): @@ -17,10 +18,11 @@ def build_model(args): model = MattingModule(net_encoder, net_decoder) - model + if CUDA_DEVICE == 'gpu': + model.cuda() if(args.weights != 'default'): - sd = torch.load(args.weights, map_location=torch.device('cpu')) + sd = torch.load(args.weights, map_location=torch.device(CUDA_DEVICE)) model.load_state_dict(sd, strict=True) return model diff --git a/networks/transforms.py b/networks/transforms.py index 30403fb..281e0aa 100644 --- a/networks/transforms.py +++ b/networks/transforms.py @@ -2,6 +2,7 @@ import numpy as np import torch import cv2 +from cuda_device import CUDA_DEVICE def dt(a): @@ -51,6 +52,8 @@ def groupnorm_denormalise_image(img, format='nhwc'): img[:, :, :, i] = img[:, :, :, i] * group_norm_std[i] + group_norm_mean[i] else: img1 = torch.zeros_like(img) + if CUDA_DEVICE == 'gpu': + img1 = img1.cuda() for i in range(3): img1[:, i, :, :] = img[:, i, :, :] * group_norm_std[i] + group_norm_mean[i] return img1