@sup3rgiu

Right now, torchinfo uses module.__class__.__name__ to retrieve the nn.Module name shown in the summary. However, every PyTorch module exposes a _get_name() method, whose default implementation simply returns self.__class__.__name__. A custom layer can override this method to return a custom module name, which avoids directly overwriting self.__class__.__name__ (generally a bad idea).
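
As a minimal sketch of the proposed lookup (the helper name here is hypothetical, not torchinfo's actual internal API), the summary code could prefer _get_name() and fall back to the class name:

import torch

def get_module_name(module: torch.nn.Module) -> str:
    # Hypothetical helper: prefer the overridable _get_name() hook if present,
    # otherwise fall back to the class name (torchinfo's current behavior).
    get_name = getattr(module, "_get_name", None)
    return get_name() if callable(get_name) else module.__class__.__name__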

Demo:

import torch
from torchinfo import summary


class CustomLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # self.__class__.__name__ = self._get_name()  # discouraged alternative: overwrite the class name directly

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input

    def _get_name(self) -> str:
        return "CustomFancyName"


img = torch.randn(1, 3, 224, 224)
model = CustomLayer()
summary(model, input_size=[img.shape], dtypes=[torch.float32], depth=2)

Output:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CustomFancyName                         [1, 3, 224, 224]          --
==========================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.60
==========================================================================================
