diff --git a/layers.py b/layers.py index 1279187..c375552 100644 --- a/layers.py +++ b/layers.py @@ -20,11 +20,12 @@ class SinkhornDistance(nn.Module): - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)` - Output: :math:`(N)` or :math:`()`, depending on `reduction` """ - def __init__(self, eps, max_iter, reduction='none'): + def __init__(self, eps, max_iter, device=None, reduction='none'): super(SinkhornDistance, self).__init__() self.eps = eps self.max_iter = max_iter self.reduction = reduction + self.device = device def forward(self, x, y): # The Sinkhorn algorithm takes as input three variables : @@ -38,12 +39,12 @@ def forward(self, x, y): # both marginals are fixed with equal weights mu = torch.empty(batch_size, x_points, dtype=torch.float, - requires_grad=False).fill_(1.0 / x_points).squeeze() + requires_grad=False, device=self.device).fill_(1.0 / x_points).squeeze() nu = torch.empty(batch_size, y_points, dtype=torch.float, - requires_grad=False).fill_(1.0 / y_points).squeeze() + requires_grad=False, device=self.device).fill_(1.0 / y_points).squeeze() - u = torch.zeros_like(mu) - v = torch.zeros_like(nu) + u = torch.zeros_like(mu, device=self.device) + v = torch.zeros_like(nu, device=self.device) # To check if algorithm terminates because of threshold # or max iterations reached actual_nits = 0