Hey,
this is a question more than an issue.
In your paper, the SortNet layer looks something like this: `x_k = w.T @ sort(act(x + b_k))`.
This should be implemented in `NormDist(Base)`, right? I assume a simple MLP, no dropout or anything, so the first significant line in `NormDistBase.forward` is this:
https://github.com/zbh2047/SortNet/blob/main/core/models/norm_dist.py#L94
```python
x = norm_dist(x, self.weight, self.p, self.groups, tag=self.tag)
```
and then the next thing (no mean shifting) is
```python
x, lower, upper = apply_if_not_none((x, lower, upper), lambda z: z + bias.unsqueeze(-1))
```
In `norm_dist` (not using the custom CUDA fn, because I'm on CPU):
```python
y = input.view(input.size(0), groups, 1, -1, input.size(2)) - weight.view(groups, -1, weight.size(-1), 1)
```
That's the only time the weight seems to be used, and it's not being used for multiplication.
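If I read that broadcast correctly, it just forms the elementwise differences `x_i - w_{j,i}`, which the rest of `norm_dist` then reduces with an l_p norm, i.e. the layer computes distances `||x - w_j||_p` instead of dot products. Here is a plain-PyTorch sketch of what I think the non-CUDA path boils down to (ignoring the `groups` dimension and assuming finite `p`, both my simplifications):

```python
import torch

def norm_dist_sketch(x, weight, p):
    # x: (batch, fan_in), weight: (fan_out, fan_in)
    # broadcast to (batch, fan_out, fan_in): x_i - w_{j,i}
    diff = x.unsqueeze(1) - weight.unsqueeze(0)
    # reduce over fan_in with an l_p norm -> (batch, fan_out)
    return diff.abs().pow(p).sum(dim=-1).pow(1.0 / p)
```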
I also don't find the sorting, or biases of the right shape (from your equation in the paper it seems like the biases should have a shape like `[fan_out, fan_in]`, but in `NormDistBase` I only find the usual neural-network default `[fan_out]` bias).
I am clearly missing something; can you enlighten me on how to write a fully connected SortNet MLP?
My attempt at a single layer was this:
```python
import torch


class SortNetLayer(torch.nn.Module):
    def __init__(self, fan_in, fan_out, lip=1, kind="inf"):
        super().__init__()
        self.weight = torch.nn.Linear(fan_in, fan_out, bias=False).weight
        # per-output, per-input bias, shape [fan_out, fan_in], as the paper's equation suggests
        self.bias = torch.nn.Parameter(torch.randn(fan_out, fan_in))
        # placeholder: "<a longer function that norms in the right way>"
        self.norm_fn = lambda x: x
        self.act_fn = lambda x: x
        self.sort_fn = lambda x: torch.sort(x, dim=-1)[0]

    def forward(self, x):
        # (batch, 1, fan_in) + (fan_out, fan_in) -> (batch, fan_out, fan_in)
        x = x.unsqueeze(1) + self.bias
        x = self.act_fn(x)
        x = self.sort_fn(x)
        W = self.norm_fn(self.weight)
        # per output unit: dot product of its (normed) weight row with its sorted activations
        return torch.einsum("oi,boi->bo", W, x)
```
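For reference, this is how I call it (just random data to sanity-check the shapes; the norm function is stubbed out as the identity above):

```python
layer = SortNetLayer(fan_in=4, fan_out=3)
x = torch.randn(8, 4)   # batch of 8 inputs
out = layer(x)
print(out.shape)        # torch.Size([8, 3])
```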