Question about NormDistBase #1

@niklasnolte

Hey,
this is a question more than an issue.

In your paper, the SortNet layer looks something like this: x_k = w.T @ sort(act(x + b_k)).
This should be implemented in NormDist(Base), right? I assume a simple MLP with no dropout or anything, so the first significant line in NormDistBase.forward is this:
https://github.com/zbh2047/SortNet/blob/main/core/models/norm_dist.py#L94
x = norm_dist(x, self.weight, self.p, self.groups, tag=self.tag)
and then the next step (with no mean shifting) is
x, lower, upper = apply_if_not_none((x, lower, upper), lambda z: z + bias.unsqueeze(-1))

In norm_dist (not using the custom CUDA function, because I'm on CPU):
y = input.view(input.size(0), groups, 1, -1, input.size(2)) - weight.view(groups, -1, weight.size(-1), 1)

That's the only place the weight seems to be used, and it's not being used for multiplication.
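
To make the shapes in that line concrete, here is a minimal sketch. It assumes the input has shape (batch, fan_in, spatial) and the weight has shape (fan_out, fan_in // groups), which is what the view calls suggest. The reduction at the end is only an illustrative guess based on the function name; the issue does not show what norm_dist does with y afterwards.

```python
import torch

torch.manual_seed(0)
batch, fan_in, fan_out, groups = 4, 6, 3, 1

# Assumed shapes, inferred from the view() calls in the quoted line.
x = torch.randn(batch, fan_in, 1)                # (batch, fan_in, spatial)
weight = torch.randn(fan_out, fan_in // groups)  # (fan_out, fan_in // groups)

# The quoted line: a broadcast difference between every input and every weight row.
y = x.view(x.size(0), groups, 1, -1, x.size(2)) - weight.view(groups, -1, weight.size(-1), 1)
print(y.shape)  # (batch, groups, fan_out // groups, fan_in // groups, spatial)

# Illustrative assumption only: reduce the fan-in axis with an infinity norm.
dist = y.abs().amax(dim=3).view(batch, fan_out, -1)
print(dist.shape)  # (batch, fan_out, spatial)
```

So the weight enters as an offset that is subtracted from the input, one row per output unit, rather than as a matrix that multiplies it.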

Also, I can't find the sorting, or biases of the right shape. From the equation in your paper, it seems the biases should have a shape like [fan_out, fan_in], but in NormDistBase I only find the usual neural-network bias of shape [fan_out].
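
One observation on the shapes, as a numerical check rather than a statement about how the repository works: in the special case where act is the absolute value and w selects only the largest sorted entry, the layer equation reduces to an infinity-norm distance between x and a tensor of shape [fan_out, fan_in] equal to -b. The sketch below verifies that identity; the choice of act and w is an assumption made for illustration.

```python
import torch

torch.manual_seed(0)
batch, fan_in, fan_out = 4, 6, 3
x = torch.randn(batch, fan_in)
b = torch.randn(fan_out, fan_in)  # per-output bias of shape [fan_out, fan_in]

# Sort-layer form with act = abs and a weight that picks only the largest entry.
w = torch.zeros(fan_in)
w[-1] = 1.0  # torch.sort is ascending, so the last slot holds the maximum
sorted_vals = torch.sort((x.unsqueeze(1) + b).abs(), dim=-1).values
sort_out = sorted_vals @ w  # (batch, fan_out)

# Distance form: infinity-norm distance between x and a "weight" equal to -b.
dist_out = (x.unsqueeze(1) - (-b)).abs().amax(dim=-1)

print(torch.allclose(sort_out, dist_out))
```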

I am clearly missing something. Can you enlighten me on how to write a fully connected SortNet MLP?

My attempt at a single layer was this:

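import torch


class SortNetLayer(torch.nn.Module):
    def __init__(self, fan_in, fan_out, lip=1, kind="inf"):
        super().__init__()
        # Weight of shape [fan_out, fan_in], borrowed from a Linear layer for its default init.
        self.weight = torch.nn.Linear(fan_in, fan_out, bias=False).weight
        # One bias vector per output unit: shape [fan_out, fan_in].
        self.bias = torch.nn.Parameter(torch.randn(fan_out, fan_in))
        # Placeholder: the real norm_fn is omitted here and would use lip and kind.
        self.norm_fn = lambda x: "<a longer function that norms in the right way>"
        self.act_fn = lambda x: x  # identity activation
        self.sort_fn = lambda x: torch.sort(x)[0]  # ascending sort along the last dim

    def forward(self, x):
        # x: [batch, fan_in] -> [batch, fan_out, fan_in] after adding the per-output bias
        x = x.unsqueeze(1) + self.bias
        x = self.act_fn(x)
        x = self.sort_fn(x)
        W = self.norm_fn(self.weight)
        # Per-output dot product between the normed weight row and the sorted activations.
        return torch.einsum("oi,boi->bo", W, x)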
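
For reference, here is a self-contained functional variant of the attempt above that actually runs. The helper l1_normalize_rows is hypothetical: it stands in for the omitted norm_fn under the assumption that the goal is a layer that is 1-Lipschitz in the infinity norm, which holds when each weight row has L1 norm at most 1 (sorting and the identity activation are both 1-Lipschitz in that norm). It is a sketch under those assumptions, not the repository's implementation.

```python
import torch


def l1_normalize_rows(w: torch.Tensor, lip: float = 1.0) -> torch.Tensor:
    # Hypothetical stand-in for norm_fn: shrink rows whose L1 norm exceeds `lip`.
    row_norms = w.abs().sum(dim=1, keepdim=True)
    return w * lip / torch.clamp(row_norms, min=lip)


def sort_layer(x, weight, bias, lip=1.0):
    # x: (batch, fan_in); weight and bias: (fan_out, fan_in)
    z = x.unsqueeze(1) + bias         # (batch, fan_out, fan_in)
    z = torch.sort(z, dim=-1).values  # identity activation, then ascending sort
    return torch.einsum("oi,boi->bo", l1_normalize_rows(weight, lip), z)


torch.manual_seed(0)
fan_in, fan_out = 5, 3
weight = torch.randn(fan_out, fan_in)
bias = torch.randn(fan_out, fan_in)
x1 = torch.randn(8, fan_in)
x2 = x1 + 0.1 * torch.randn(8, fan_in)

out1 = sort_layer(x1, weight, bias)
out2 = sort_layer(x2, weight, bias)

# Empirical check of the Lipschitz bound in the infinity norm.
in_gap = (x1 - x2).abs().amax(dim=1)
out_gap = (out1 - out2).abs().amax(dim=1)
print(out1.shape)
print(bool((out_gap <= in_gap + 1e-5).all()))
```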
