Thanks for the wonderful work!
I wonder if your project supports multi-GPU settings. When I tried to run inference on a batch of images across multiple GPUs, several exceptions popped up in the StyleGAN2-related classes.
Initially, it raised a tensor device mismatch for the following code:
if truncation < 1:
    style_t = []
    for style in styles:
        style_t.append(
            truncation_latent + truncation * (style - truncation_latent)
        )
    styles = style_t
It points at style - truncation_latent, claiming the two tensors are on different devices. I modified it to style.to(truncation_latent.device) - truncation_latent and got past this error. However, I am confused about why this happened. Is it because of the for loop? Or is there a setting for it somewhere?
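For reference, here is a minimal sketch of the workaround generalized: move every style code onto the truncation latent's device once, before the arithmetic, instead of patching each expression. truncate_styles is a hypothetical helper, not part of the repo:

```python
import torch

# Hypothetical helper (not in the repo): apply truncation with each
# style code moved to the truncation latent's device first, so the
# subtraction never mixes devices.
def truncate_styles(styles, truncation, truncation_latent):
    device = truncation_latent.device
    return [
        truncation_latent + truncation * (s.to(device) - truncation_latent)
        for s in styles
    ]

# CPU stand-in for the multi-GPU case; the result always follows
# truncation_latent's device.
latent = torch.zeros(1, 512)
styles = [torch.ones(1, 512)]
out = truncate_styles(styles, 0.7, latent)
print(out[0].device)
```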
After this, I encountered several similar cases and fixed them by calling .to(some_tensor.device). Then the following occurred:
File "/projects/stylemask/libs/utilities/utils_inference.py", line 210, in invert_image
inverted_images, _ = generator([latent_codes], input_is_latent=True, return_latents = False, truncation= truncation, truncation_latent=trunc)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/stylemask/libs/models/StyleGAN2/model.py", line 519, in forward
out = self.conv1(out, latent[:, 0], noise=noise[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/stylemask/libs/models/StyleGAN2/model.py", line 332, in forward
out = self.conv(input, style)
^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/stylemask/libs/models/StyleGAN2/model.py", line 234, in forward
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/stylemask/libs/models/StyleGAN2/model.py", line 154, in forward
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
I figured out that input and self.weight are on different devices. I have never encountered this before, because I wrapped the model in nn.DataParallel(), and I am not sure why self.weight, which is an nn.Parameter(), ends up on a different device from the input.
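From what I understand (a general PyTorch point, not specific to this repo): nn.DataParallel replicates registered parameters and buffers onto each GPU, but plain tensor attributes are invisible both to it and to model.to(device), so they stay wherever they were created. Since self.weight here is an nn.Parameter, this exact case may not apply, but the same pitfall shows up for tensors created in forward() (e.g. noise) or stored in plain Python containers. Also note that tensor.to(device) returns a new tensor rather than moving in place, so a bare self.weight.to(input.device) statement only helps when its result is used inline or reassigned. A minimal illustration of the registration difference:

```python
import torch
import torch.nn as nn

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter is registered, so model.to(device) moves it and
        # nn.DataParallel replicates it onto each GPU.
        self.weight = nn.Parameter(torch.randn(4, 4))

class Unregistered(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor attribute is invisible to .to()/.cuda() and to
        # DataParallel, so it stays on whatever device created it.
        self.weight = torch.randn(4, 4)

print("weight" in dict(Registered().named_parameters()))   # True
print("weight" in dict(Unregistered().named_parameters())) # False
```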
Well, I had to keep fixing these by continuing to do things like self.weight.to(input.device). Finally, it reaches:
mask_idx = self.mask_net[network_name_str](styles)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/projects/stylemask/libs/models/mask_predictor.py", line 23, in forward
out = self.masknet(input)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/pytorch-2.1.2/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
It turns out that the error happens in mask_predictor, and that class contains just a sequential network:
self.masknet = nn.Sequential(
    nn.Linear(input_dim, inner_dim, bias=True),
    nn.ReLU(),
    nn.Linear(inner_dim, output_dim, bias=True),
)
I don't know why out = self.masknet(input) would raise the device mismatch error.
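One guess, which I have not verified against the repo: if mask_net is a plain Python dict of sub-networks rather than an nn.ModuleDict, those Linear layers are never registered with the parent module, so model.to(device) and DataParallel's replication skip them and they stay on the device they were created on (cuda:0). A quick way to check the difference, sketched with hypothetical names:

```python
import torch
import torch.nn as nn

class WithModuleDict(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered container: .to(device) / DataParallel see these layers.
        self.mask_net = nn.ModuleDict({"layer_4": nn.Linear(8, 8)})

class WithPlainDict(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain dict: the Linear is hidden from PyTorch's module machinery,
        # so moving or replicating the parent leaves it behind.
        self.mask_net = {"layer_4": nn.Linear(8, 8)}

print(len(list(WithModuleDict().parameters())))  # 2: weight and bias
print(len(list(WithPlainDict().parameters())))   # 0: nothing registered
```

If len(list(model.parameters())) misses the masknet weights, that would explain why the input is scattered to cuda:1 while the masknet stays on cuda:0.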
I had to omit some messages since I was running StyleMask inference from within another project. I hope the messages above explain my problem, and I look forward to your reply.
Thanks!