diff --git a/cuda_device.py b/cuda_device.py new file mode 100644 index 0000000..584ad32 --- /dev/null +++ b/cuda_device.py @@ -0,0 +1,3 @@ +import torch + +CUDA_DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu' diff --git a/demo.py b/demo.py index 505382a..d1e2bee 100644 --- a/demo.py +++ b/demo.py @@ -2,6 +2,7 @@ from networks.transforms import trimap_transform, groupnorm_normalise_image from networks.models import build_model from dataloader import PredDataset +from cuda_device import CUDA_DEVICE # System libs import os @@ -14,7 +15,10 @@ def np_to_torch(x): - return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda() + val = torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float() + if CUDA_DEVICE == 'gpu': + val = val.cuda() + return val def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray: diff --git a/networks/models.py b/networks/models.py index a0ea431..821f339 100644 --- a/networks/models.py +++ b/networks/models.py @@ -3,6 +3,7 @@ import networks.resnet_GN_WS as resnet_GN_WS import networks.layers_WS as L import networks.resnet_bn as resnet_bn +from cuda_device import CUDA_DEVICE def build_model(args): @@ -17,10 +18,11 @@ def build_model(args): model = MattingModule(net_encoder, net_decoder) - model.cuda() + if CUDA_DEVICE == 'gpu': + model.cuda() if(args.weights != 'default'): - sd = torch.load(args.weights) + sd = torch.load(args.weights, map_location=torch.device(CUDA_DEVICE)) model.load_state_dict(sd, strict=True) return model diff --git a/networks/transforms.py b/networks/transforms.py index 860c2b8..281e0aa 100644 --- a/networks/transforms.py +++ b/networks/transforms.py @@ -2,6 +2,7 @@ import numpy as np import torch import cv2 +from cuda_device import CUDA_DEVICE def dt(a): @@ -50,7 +51,9 @@ def groupnorm_denormalise_image(img, format='nhwc'): for i in range(3): img[:, :, :, i] = img[:, :, :, i] * group_norm_std[i] + group_norm_mean[i] else: - img1 = torch.zeros_like(img).cuda() + img1 = torch.zeros_like(img) + if CUDA_DEVICE == 'gpu': + img1 = img1.cuda() for i in range(3): img1[:, i, :, :] = img[:, i, :, :] * group_norm_std[i] + group_norm_mean[i] return img1