
Added device option to Wasserstein #3

Open

sAbhay wants to merge 1 commit into dfdazac:master from sAbhay:master

Conversation


sAbhay commented Jun 20, 2022

Updated wasserstein submodule to include device
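
The diff itself isn't shown here, but a device option added at construction time would presumably look something like the sketch below (the class and parameter names are assumptions for illustration, not the actual patch):

import torch
import torch.nn as nn

class SinkhornDistance(nn.Module):
    # Illustrative sketch only: the device is fixed when the module is created.
    def __init__(self, eps, max_iter, reduction='none', device='cpu'):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.device = device  # internal tensors would later be allocated on this device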

dfdazac (Owner) commented Aug 25, 2022

@sAbhay thank you for your contribution! Sorry for the long delay in my response.
To keep compatibility with distributed training, where computations could run on different devices, I think it would be better to grab the device during the forward pass, rather than fixing it during initialization. For example,

def forward(self, x, y):
    device = x.device     
    ...

    mu = torch.empty(batch_size, x_points, dtype=torch.float,
                     requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
    ...

This way, computations will run on whatever device x happens to be on. What do you think?
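
For clarity, here is a minimal self-contained sketch of that pattern (the class name is made up for illustration and is not part of the repository): no device is stored on the module, and buffers are created wherever the input lives.

import torch
import torch.nn as nn

class DeviceAgnosticExample(nn.Module):
    # Buffers are allocated on the device of the incoming tensor,
    # so the same module instance works on CPU, a single GPU, or
    # per-replica devices in distributed training.
    def forward(self, x):
        device = x.device
        batch_size, x_points = x.shape[0], x.shape[1]
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False, device=device).fill_(1.0 / x_points).squeeze()
        return mu

module = DeviceAgnosticExample()
x_cpu = torch.randn(4, 100, 2)
print(module(x_cpu).device)                  # cpu
if torch.cuda.is_available():
    print(module(x_cpu.to('cuda')).device)   # cuda:0, no changes to the module needed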
