diff --git a/.github/workflows/.pylintrc b/.github/workflows/.pylintrc index 12ccc22..65493e9 100644 --- a/.github/workflows/.pylintrc +++ b/.github/workflows/.pylintrc @@ -89,8 +89,8 @@ enable= E0601,E0602,E0603,E0604,E0611, E0632,E0633,E0701,E0702,E0703, E0704,E0710,E0711,E0712,E1003, - #E1102, - E1111,E1120,E1121,E1123, + ; #E1102, E1120, + E1111,E1121,E1123, E1124,E1125,E1126,E1127,E1128, E1129,E1130,E1131,E1132,E1133, E1134,E1135,E1136,E1137,E1138, diff --git a/chameleon/__init__.py b/chameleon/__init__.py index be1aed3..e377bf4 100644 --- a/chameleon/__init__.py +++ b/chameleon/__init__.py @@ -1,10 +1,6 @@ -from .backbone import * -from .efficientdet import * +from .base import * from .metrics import * -from .neck import * -from .nn import * -from .optim import * +from .modules import * from .tools import * -from .transformers import * __version__ = '0.1.0' diff --git a/chameleon/base/__init__.py b/chameleon/base/__init__.py new file mode 100644 index 0000000..7c8749e --- /dev/null +++ b/chameleon/base/__init__.py @@ -0,0 +1,8 @@ +from .blocks import build_block, list_blocks +from .components import build_component, list_components +from .layers import build_layer, list_layers +from .optim import (build_lr_scheduler, build_optimizer, list_lr_schedulers, + list_optimizers) +from .power_module import PowerModule +from .utils import (has_children, initialize_weights_, replace_module, + replace_module_attr_value) diff --git a/chameleon/base/blocks/__init__.py b/chameleon/base/blocks/__init__.py new file mode 100644 index 0000000..2eeedf6 --- /dev/null +++ b/chameleon/base/blocks/__init__.py @@ -0,0 +1,21 @@ +import fnmatch + +from .conv_block import Conv2dBlock, SeparableConv2dBlock + +# from .mamba_block import build_mamba_block +# from .vit_block import build_vit_block + + +def build_block(name, **kwargs): + cls = globals().get(name, None) + if cls is None: + raise ValueError(f'Block named {name} is not support.') + return cls(**kwargs) + + +def list_blocks(filter=''): + block_list = [k for k in globals().keys() if 'Block' in k] + if len(filter): + return fnmatch.filter(block_list, filter) # include these blocks + else: + return block_list diff --git a/chameleon/base/blocks/conv_block.py b/chameleon/base/blocks/conv_block.py new file mode 100644 index 0000000..e10b330 --- /dev/null +++ b/chameleon/base/blocks/conv_block.py @@ -0,0 +1,192 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..components import build_component +from ..power_module import PowerModule + + +class SeparableConv2dBlock(PowerModule): + + def __init__( + self, + in_channels: int, + out_channels: int = None, + kernel: Union[int, Tuple[int, int]] = 3, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 1, + bias: bool = False, + inner_norm: Optional[Union[dict, nn.Module]] = None, + inner_act: Optional[Union[dict, nn.Module]] = None, + norm: Optional[Union[dict, nn.Module]] = None, + act: Optional[Union[dict, nn.Module]] = None, + init_type: str = 'normal', + ): + """ + A separable convolution block consisting of a depthwise convolution and a pointwise convolution. + + Args: + in_channels (int): + Number of input channels. + out_channels (int, optional): + Number of output channels. If not provided, defaults to `in_channels`. + kernel (int or Tuple[int, int], optional): + Size of the convolution kernel. Defaults to 3. + stride (int or Tuple[int, int], optional): + Stride of the convolution. Defaults to 1. + padding (int or Tuple[int, int], optional): + Padding added to all four sides of the input. Defaults to 1. + bias (bool): + Whether to include a bias term in the convolutional layer. + Noted: if normalization layer is not None, bias will always be set to False. + Defaults to False. + inner_norm (dict or nn.Module, optional): + Configuration of normalization layer between dw and pw layer. Defaults to None. + inner_act (dict or nn.Module, optional): + Configuration of activation layer between dw and pw layer. Defaults to None. + norm (dict or nn.Module, optional): + Configuration of normalization layer after pw layer. Defaults to None. + act (dict or nn.Module, optional): + Configuration of activation layer after pw layer. Defaults to None. + init_type (str, optional): + Initialization method for the model parameters. Defaults to 'normal'. + """ + super().__init__() + out_channels = in_channels if out_channels is None else out_channels + + bias = False if norm is not None else bias + + self.block = nn.ModuleDict() + + self.block['dw_conv'] = nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel, + stride=stride, + padding=padding, + groups=in_channels, + bias=False, + ) + self.block['pw_conv'] = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + if inner_norm is not None: + self.block['inner_norm'] = build_component(**inner_norm) if isinstance(inner_norm, dict) else inner_norm + if inner_act is not None: + self.block['inner_act'] = build_component(**inner_act) if isinstance(inner_act, dict) else inner_act + if norm is not None: + self.block['norm'] = build_component(**norm) if isinstance(norm, dict) else norm + if act is not None: + self.block['act'] = build_component(**act) if isinstance(act, dict) else act + self.initialize_weights_(init_type) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for _, m in self.block.items(): + x = m(x) + return x + + +class Conv2dBlock(PowerModule): + + def __init__( + self, + in_channels: Union[float, int], + out_channels: Union[float, int], + kernel: Union[int, Tuple[int, int]] = 3, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = False, + padding_mode: str = 'zeros', + norm: Union[dict, nn.Module] = None, + act: Union[dict, nn.Module] = None, + init_type: str = 'normal', + ): + """ + This class is used to build a 2D convolutional neural network cell. + + Args: + in_channels (int or float): + Number of input channels. + out_channels (int or float): + Number of output channels. + kernel (int or tuple, optional): + Size of the convolutional kernel. Defaults to 3. + stride (int or tuple, optional): + Stride size. Defaults to 1. + padding (int or tuple, optional): + Padding size. Defaults to 1. + dilation (int, optional): + Spacing between kernel elements. Defaults to 1. + groups (int, optional): + Number of blocked connections from input channels to output + channels. Defaults to 1. + bias (bool, optional): + Whether to include a bias term in the convolutional layer. + If bias = None, bias would be set as Ture when normalization layer is None and + False when normalization layer is not None. + Defaults to None. + padding_mode (str, optional): + Options = {'zeros', 'reflect', 'replicate', 'circular'}. + Defaults to 'zeros'. + norm (Union[dict, nn.Module], optional): + normalization layer or a dictionary of arguments for building a + normalization layer. Default to None. + act (Union[dict, nn.Module], optional): + Activation function or a dictionary of arguments for building an + activation function. Default to None. + pool (Union[dict, nn.Module], optional): + pooling layer or a dictionary of arguments for building a pooling + layer. Default to None. + init_type (str): + Method for initializing model parameters. Default to 'normal'. + Options = {'normal', 'uniform'} + + Examples for using norm, act, and pool: + 1. conv_block = Conv2dBlock(in_channels=3, + out_channels=12, + norm=nn.BatchNorm2d(12), + act=nn.ReLU(), + pool=nn.AdaptiveAvgPool2d(1)) + 2. conv_block = Conv2dBlock(in_channels=3, + out_channels=12, + norm={'name': 'BatchNorm2d', 'num_features': 12}, + act={'name': 'ReLU', 'inplace': True}) + + Attributes: + block (nn.ModuleDict): a model block. + """ + super().__init__() + self.block = nn.ModuleDict() + + bias = False if norm is not None else bias + + self.block['conv'] = nn.Conv2d( + int(in_channels), + int(out_channels), + kernel_size=kernel, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + if norm is not None: + self.block['norm'] = build_component(**norm) if isinstance(norm, dict) else norm + if act is not None: + self.block['act'] = build_component(**act) if isinstance(act, dict) else act + + self.initialize_weights_(init_type) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for _, m in self.block.items(): + x = m(x) + return x diff --git a/chameleon/base/blocks/mamba_block.py b/chameleon/base/blocks/mamba_block.py new file mode 100644 index 0000000..dd9d074 --- /dev/null +++ b/chameleon/base/blocks/mamba_block.py @@ -0,0 +1,2 @@ +def build_mamba_block(**kwargs): + return print('to be implemented') diff --git a/chameleon/base/blocks/vit_block.py b/chameleon/base/blocks/vit_block.py new file mode 100644 index 0000000..2ced050 --- /dev/null +++ b/chameleon/base/blocks/vit_block.py @@ -0,0 +1,2 @@ +def build_vit_block(**kwargs): + return print('to be implemented') diff --git a/chameleon/base/components/__init__.py b/chameleon/base/components/__init__.py new file mode 100644 index 0000000..9502ffd --- /dev/null +++ b/chameleon/base/components/__init__.py @@ -0,0 +1,29 @@ +import fnmatch + +from .activation import * +from .dropout import * +from .loss import * +from .norm import * +from .pooling import * + + +def build_component_cls(name): + cls = globals().get(name, None) + if cls is None: + raise ValueError(f'Component named {name} is not support.') + return cls + + +def build_component(name, **options): + cls = globals().get(name, None) + if cls is None: + raise ValueError(f'Component named {name} is not support.') + return cls(**options) + + +def list_components(filter=''): + component_list = [k for k in globals().keys() if 'Component' in k] + if len(filter): + return fnmatch.filter(component_list, filter) # include these components + else: + return component_list diff --git a/chameleon/nn/components/activation.py b/chameleon/base/components/activation.py similarity index 91% rename from chameleon/nn/components/activation.py rename to chameleon/base/components/activation.py index eee225d..82368c1 100644 --- a/chameleon/nn/components/activation.py +++ b/chameleon/base/components/activation.py @@ -13,7 +13,7 @@ Tanhshrink, Threshold) __all__ = [ - 'Swish', 'Hsigmoid', 'Hswish', 'build_activation', 'StarReLU', 'SquaredReLU', + 'Swish', 'Hsigmoid', 'Hswish', 'StarReLU', 'SquaredReLU', ] __all__ += ['CELU', 'ELU', 'GELU', 'GLU', 'LeakyReLU', 'LogSigmoid', @@ -95,10 +95,3 @@ def forward(self, x): # Ref: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html Swish = nn.SiLU - - -def build_activation(name, **options) -> Union[nn.Module, None]: - cls = globals().get(name, None) - if cls is None: - raise ValueError(f'Activation named {name} is not supported.') - return cls(**options) diff --git a/chameleon/base/components/dropout.py b/chameleon/base/components/dropout.py new file mode 100644 index 0000000..6232ba0 --- /dev/null +++ b/chameleon/base/components/dropout.py @@ -0,0 +1,6 @@ +import torch.nn as nn +from torch.nn import AlphaDropout, Dropout, Dropout2d, Dropout3d + +__all__ = [ + 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', +] diff --git a/chameleon/nn/components/loss.py b/chameleon/base/components/loss.py similarity index 93% rename from chameleon/nn/components/loss.py rename to chameleon/base/components/loss.py index aac682b..61e48f9 100644 --- a/chameleon/nn/components/loss.py +++ b/chameleon/base/components/loss.py @@ -8,7 +8,7 @@ L1Loss, MSELoss, SmoothL1Loss) __all__ = [ - 'build_loss', 'AWingLoss', 'WeightedAWingLoss', + 'AWingLoss', 'WeightedAWingLoss', 'BCELoss', 'BCEWithLogitsLoss', 'CrossEntropyLoss', 'CTCLoss', 'KLDivLoss', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'ArcFace', 'CosFace', 'LogCoshDiceLoss', @@ -83,14 +83,6 @@ def forward(self, preds, targets, weight_map=None): return weighted.mean() -def build_loss(name: str, **options) -> Union[nn.Module, None]: - """Build a loss func layer given the name and options.""" - cls = globals().get(name, None) - if cls is None: - raise KeyError(f'Unsupported loss func: {name}') - return cls(**options) - - class ArcFace(nn.Module): def __init__(self, s=64.0, m=0.5): diff --git a/chameleon/nn/components/norm.py b/chameleon/base/components/norm.py similarity index 84% rename from chameleon/nn/components/norm.py rename to chameleon/base/components/norm.py index 5fa48e3..249e7f1 100644 --- a/chameleon/nn/components/norm.py +++ b/chameleon/base/components/norm.py @@ -13,7 +13,7 @@ __all__ = [ 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'CrossMapLRN2d', 'GroupNorm', 'LayerNorm', - 'LocalResponseNorm', 'build_norm', 'LayerNorm2d', + 'LocalResponseNorm', 'LayerNorm2d', ] @@ -42,11 +42,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) return x - - -def build_norm(name: str, **options) -> Union[nn.Module, None]: - cls = globals().get(name, None) - if cls is None: - raise ValueError( - f'Normalization named {name} is not supported. Available options: {__all__}') - return cls(**options) diff --git a/chameleon/nn/components/pooling.py b/chameleon/base/components/pooling.py similarity index 80% rename from chameleon/nn/components/pooling.py rename to chameleon/base/components/pooling.py index 1c73db1..f84723d 100644 --- a/chameleon/nn/components/pooling.py +++ b/chameleon/base/components/pooling.py @@ -9,7 +9,7 @@ MaxPool1d, MaxPool2d, MaxPool3d) __all__ = [ - 'build_pool', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', + 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'GAP', 'GMP', @@ -44,11 +44,3 @@ def __init__(self): def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply global max pooling on the input tensor.""" return self.pool(x) - - -def build_pool(name: str, **options) -> Union[nn.Module, None]: - """Build a pooling layer given the name and options.""" - cls = globals().get(name, None) - if cls is None: - raise KeyError(f'Unsupported pooling layer: {name}') - return cls(**options) diff --git a/chameleon/base/layers/__init__.py b/chameleon/base/layers/__init__.py new file mode 100644 index 0000000..1d8a493 --- /dev/null +++ b/chameleon/base/layers/__init__.py @@ -0,0 +1,22 @@ +import fnmatch + +from .aspp import ASPP +from .grl import GradientReversalLayer +from .selayer import SELayer +from .vae import VAE +from .weighted_sum import WeightedSum + + +def build_layer(name, **options): + cls = globals().get(name, None) + if cls is None: + raise ValueError(f'Layer named {name} is not support.') + return cls(**options) + + +def list_layers(filter=''): + layer_list = [k for k in globals().keys() if 'Layer' in k] + if len(filter): + return fnmatch.filter(layer_list, filter) # include these layers + else: + return layer_list diff --git a/chameleon/nn/aspp.py b/chameleon/base/layers/aspp.py similarity index 83% rename from chameleon/nn/aspp.py rename to chameleon/base/layers/aspp.py index 3e9773f..952a2ce 100644 --- a/chameleon/nn/aspp.py +++ b/chameleon/base/layers/aspp.py @@ -1,11 +1,13 @@ +from copy import deepcopy + import torch import torch.nn as nn -from .cnn import CNN2Dcell -from .components import Hswish -from .utils import PowerModule +from ..blocks import build_block +from ..components import Hswish +from ..power_module import PowerModule -__all__ = ['ASPPLayer'] +__all__ = ['ASPP'] __doc__ = """ REFERENCES: DeepLab: Semantic Image Segmentation with Deep Convolutional @@ -14,7 +16,7 @@ """ -class ASPPLayer(PowerModule): +class ASPP(PowerModule): ARCHS = { # ksize, stride, padding, dilation, is_use_hs @@ -31,7 +33,7 @@ def __init__( output_activate: nn.Module = nn.ReLU(), ): """ - Constructor for the ASPPLayer class. + Constructor for the ASPP class. Args: in_channels (int): @@ -45,7 +47,8 @@ def __init__( self.layers = nn.ModuleDict() for dilate_name, cfg in self.ARCHS.items(): ksize, stride, padding, dilation, use_hs = cfg - layer = CNN2Dcell( + layer = build_block( + 'Conv2dBlock', in_channels=in_channels, out_channels=in_channels, kernel=ksize, @@ -57,14 +60,15 @@ def __init__( ) self.layers[dilate_name] = layer - self.output_layer = CNN2Dcell( + self.output_layer = build_block( + 'Conv2dBlock', in_channels=in_channels * len(self.layers), out_channels=out_channels, kernel=1, stride=1, padding=0, norm=nn.BatchNorm2d(out_channels), - act=output_activate, + act=deepcopy(output_activate), ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/chameleon/nn/grl.py b/chameleon/base/layers/grl.py similarity index 96% rename from chameleon/nn/grl.py rename to chameleon/base/layers/grl.py index 22debd7..b328b17 100644 --- a/chameleon/nn/grl.py +++ b/chameleon/base/layers/grl.py @@ -1,7 +1,7 @@ import torch from torch.autograd import Function -from .utils import PowerModule +from ..power_module import PowerModule __all__ = ['GradientReversalLayer'] diff --git a/chameleon/nn/selayer.py b/chameleon/base/layers/selayer.py similarity index 62% rename from chameleon/nn/selayer.py rename to chameleon/base/layers/selayer.py index 87b2b3f..bbbd794 100644 --- a/chameleon/nn/selayer.py +++ b/chameleon/base/layers/selayer.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from .cnn import CNN2Dcell -from .utils import PowerModule +from ..blocks import build_block +from ..power_module import PowerModule __all__ = ['SELayer'] @@ -26,8 +26,24 @@ def __init__(self, in_channels: int, reduction: int = 4): mid_channels = max(1, in_channels // reduction) self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc1 = CNN2Dcell(in_channels, mid_channels, kernel=1, stride=1, padding=0, act=nn.ReLU(False)) - self.fc2 = CNN2Dcell(mid_channels, in_channels, kernel=1, stride=1, padding=0, act=nn.Sigmoid()) + self.fc1 = build_block( + 'Conv2dBlock', + in_channels=in_channels, + out_channels=mid_channels, + kernel=1, + stride=1, + padding=0, + act=nn.ReLU(True) + ) + self.fc2 = build_block( + 'Conv2dBlock', + in_channels=mid_channels, + out_channels=in_channels, + kernel=1, + stride=1, + padding=0, + act=nn.Sigmoid(), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.avg_pool(x) diff --git a/chameleon/nn/vae.py b/chameleon/base/layers/vae.py similarity index 96% rename from chameleon/nn/vae.py rename to chameleon/base/layers/vae.py index aa59897..9396ee5 100644 --- a/chameleon/nn/vae.py +++ b/chameleon/base/layers/vae.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from .components import GAP -from .utils import PowerModule +from ..components import GAP +from ..power_module import PowerModule __all__ = ['VAE'] diff --git a/chameleon/base/layers/weighted_sum.py b/chameleon/base/layers/weighted_sum.py new file mode 100644 index 0000000..5154322 --- /dev/null +++ b/chameleon/base/layers/weighted_sum.py @@ -0,0 +1,52 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn + +from ..components import build_component + + +class WeightedSum(nn.Module): + + def __init__( + self, + input_size: int, + act: Optional[Union[dict, nn.Module]] = None, + requires_grad: bool = True, + ) -> None: + """ + Initializes a WeightedSum module. + + Args: + input_size (int): + The number of inputs to be summed. + act Optional[Union[dict, nn.Module]]: + Optional activation function or dictionary of its parameters. + Defaults to None. + requires_grad (bool, optional): + Whether to require gradients for the weights. Defaults to True. + """ + super().__init__() + self.input_size = input_size + self.weights = nn.Parameter( + torch.ones(input_size, dtype=torch.float32), + requires_grad=requires_grad + ) + self.weights_relu = nn.ReLU() + if act is None: + self.relu = nn.Identity() + else: + self.relu = act if isinstance(act, nn.Module) \ + else build_component(**act) + self.epsilon = 1e-4 + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + if len(x) != self.input_size: + raise ValueError('Invalid input size not equal to weight size.') + weights = self.weights_relu(self.weights) + weights = weights / ( + torch.sum(weights, dim=0, keepdim=True) + self.epsilon) + weighted_x = torch.einsum( + 'i,i...->...', weights, torch.stack(x, dim=0)) + weighted_x = self.relu(weighted_x) + return weighted_x diff --git a/chameleon/base/ops/__init__.py b/chameleon/base/ops/__init__.py new file mode 100644 index 0000000..334675b --- /dev/null +++ b/chameleon/base/ops/__init__.py @@ -0,0 +1 @@ +from .positional_encoding import * diff --git a/chameleon/nn/positional_encoding.py b/chameleon/base/ops/positional_encoding.py similarity index 100% rename from chameleon/nn/positional_encoding.py rename to chameleon/base/ops/positional_encoding.py diff --git a/chameleon/optim/__init__.py b/chameleon/base/optim/__init__.py similarity index 66% rename from chameleon/optim/__init__.py rename to chameleon/base/optim/__init__.py index b3e2e35..1b59c24 100644 --- a/chameleon/optim/__init__.py +++ b/chameleon/base/optim/__init__.py @@ -1,3 +1,5 @@ +import fnmatch + from torch.optim import (ASGD, LBFGS, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop, Rprop, SparseAdam) from torch.optim.lr_scheduler import (CosineAnnealingLR, @@ -21,3 +23,19 @@ def build_lr_scheduler(optimizer, name, **lr_scheduler_options): if cls_ is None: raise ValueError(f'{name} is not supported lr scheduler.') return cls_(optimizer, **lr_scheduler_options) + + +def list_optimizers(filter=''): + optimizer_list = [k for k in globals().keys() if k[0].isupper()] + if len(filter): + return [o for o in optimizer_list if filter in o.lower()] + else: + return optimizer_list + + +def list_lr_schedulers(filter=''): + lr_scheduler_list = [k for k in globals().keys() if 'LR' in k] + if len(filter): + return [o for o in lr_scheduler_list if filter in o.lower()] + else: + return lr_scheduler_list diff --git a/chameleon/optim/polynomial_lr_warmup.py b/chameleon/base/optim/polynomial_lr_warmup.py similarity index 100% rename from chameleon/optim/polynomial_lr_warmup.py rename to chameleon/base/optim/polynomial_lr_warmup.py diff --git a/chameleon/optim/warm_up.py b/chameleon/base/optim/warm_up.py similarity index 100% rename from chameleon/optim/warm_up.py rename to chameleon/base/optim/warm_up.py diff --git a/chameleon/base/power_module.py b/chameleon/base/power_module.py new file mode 100644 index 0000000..ff33d76 --- /dev/null +++ b/chameleon/base/power_module.py @@ -0,0 +1,77 @@ +from typing import List, Union + +import torch.nn as nn + +from .utils import initialize_weights_ + + +class PowerModule(nn.Module): + """ + A module that provides additional functionality for weight initialization, + freezing and melting layers. + """ + + def initialize_weights_(self, init_type: str = 'normal') -> None: + """ + Initialize the weights of the module. + + Args: + init_type (str): The type of initialization. Can be 'normal' or 'uniform'. + """ + initialize_weights_(self, init_type) + + def freeze(self, part_names: Union[str, List[str]] = 'all', verbose: bool = False) -> None: + """ + Freeze the parameters of specified layers. + + Args: + part_names (Union[str, List[str]]): The names of the layers to freeze. + If 'all', all layers are frozen. + verbose (bool): Whether to print messages indicating which layers were frozen. + """ + if part_names == 'all': + for name, params in self.named_parameters(): + if verbose: + print(f'Freezing layer {name}') + params.requires_grad_(False) + elif part_names is None: + return + else: + part_names = [part_names] if isinstance(part_names, str) \ + else part_names + for layer_name in part_names: + module = self + for attr in layer_name.split('.'): + module = getattr(module, attr) + for name, param in module.named_parameters(): + if verbose: + print(f'Freezing layer {layer_name}.{name}') + param.requires_grad_(False) + + def melt(self, part_names: Union[str, List[str]] = 'all', verbose: bool = False) -> None: + """ + Unfreeze the parameters of specified layers. + + Args: + part_names (Union[str, List[str]]): The names of the layers to unfreeze. + If 'all', all layers are unfrozen. + verbose (bool): Whether to print messages indicating which layers were unfrozen. + """ + if part_names == 'all': + for name, params in self.named_parameters(): + if verbose: + print(f'Unfreezing layer {name}') + params.requires_grad_(True) + elif part_names is None: + return + else: + part_names = [part_names] if isinstance(part_names, str) \ + else part_names + for layer_name in part_names: + module = self + for attr in layer_name.split('.'): + module = getattr(module, attr) + for name, param in module.named_parameters(): + if verbose: + print(f'Unfreezing layer {layer_name}.{name}') + param.requires_grad_(True) diff --git a/chameleon/base/utils.py b/chameleon/base/utils.py new file mode 100644 index 0000000..1e87c42 --- /dev/null +++ b/chameleon/base/utils.py @@ -0,0 +1,114 @@ +from typing import Any, Union + +import torch.nn as nn + +from .components import build_component, build_component_cls + +__all__ = ['has_children', 'replace_module', 'replace_module_attr_value', 'initialize_weights_'] + + +def has_children(module): + try: + next(module.children()) + return True + except StopIteration: + return False + + +def replace_module( + model: nn.Module, + target: Union[type, str], + dst_module: Union[nn.Module, dict] +) -> None: + """ + Function to replace modules. + + Args: + model (nn.Module): + NN module. + target (Union[type, str]): + The type of module you want to replace. + dst_module (Union[nn.Module, dict]): + The module you want to use after replacement. + """ + if not isinstance(dst_module, (nn.Module, dict)): + raise ValueError(f'dst_module = {dst_module} should be an instance of Module or dict.') + + target = build_component_cls(target) if isinstance(target, str) else target + dst_module = build_component(**dst_module) if isinstance(dst_module, dict) else dst_module + + for name, m in model.named_children(): + if has_children(m): + replace_module(m, target, dst_module) + else: + if isinstance(m, target): + setattr(model, name, dst_module) + + +def replace_module_attr_value( + model: nn.Module, + target: Union[type, str], + attr_name: str, + attr_value: Any +) -> None: + """ + Function to replace attr's value in target module + + Args: + model (nn.Module): NN module. + target (Union[type, str]): The type of module you want to modify. + attr_name (str): The name of the attribute you want to modify. + attr_value (Any): The new value of the attribute. + """ + target = build_component_cls(target) if isinstance(target, str) else target + for module in model.modules(): + if isinstance(module, target): + setattr(module, attr_name, attr_value) + + +def initialize_weights_( + model: nn.Module, + init_type: str = 'normal', +) -> None: + """ + Initialize the weights in the given model. + + Args: + model (nn.Module): + The model to initialize. + init_type (str, optional): + The initialization method to use. Supported options are 'uniform' + and 'normal'. Defaults to 'normal'. + + Raises: + TypeError: If init_type is not supported. + """ + if not isinstance(model, nn.Module): + raise TypeError( + f'model must be an instance of nn.Module, but got {type(model)}') + + init_functions = { + 'uniform': nn.init.kaiming_uniform_, + 'normal': nn.init.kaiming_normal_ + } + + if init_type not in init_functions: + raise TypeError(f'init_type {init_type} is not supported.') + nn_init = init_functions[init_type] + + def _recursive_init(m): + if has_children(m): + for child in m.children(): + _recursive_init(child) + else: + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn_init(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.affine: + nn.init.ones_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + _recursive_init(model) diff --git a/chameleon/efficientdet/__init__.py b/chameleon/efficientdet/__init__.py deleted file mode 100644 index 3a0503e..0000000 --- a/chameleon/efficientdet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .efficientdet import EfficientDet diff --git a/chameleon/efficientdet/efficientdet.py b/chameleon/efficientdet/efficientdet.py deleted file mode 100644 index ac508ba..0000000 --- a/chameleon/efficientdet/efficientdet.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import List - -import torch -from timm import create_model - -from ..neck import BiFPNs -from ..nn import PowerModule - -__all__ = ['EfficientDet'] - - -class EfficientDet(PowerModule): - - def __init__(self, compound_coef: int = 0, pretrained: bool = True, **kwargs): - """ - EfficientDet backbone. - - Args: - compound_coef (int, optional): - Compound scaling factor for the model architecture. Defaults to 0. - pretrained (bool, optional): - If True, returns a model pre-trained on ImageNet. Defaults to True. - """ - super().__init__() - self.compound_coef = compound_coef - - # Number of filters for each FPN layer at each compound coefficient - self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384] - - # Number of BiFPN repeats for each compound coefficient - self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8] - - # Number of channels for each input feature map at each compound coefficient - conv_channel_coef = { - # the channels of P3/P4/P5. - 0: [40, 112, 320], - 1: [40, 112, 320], - 2: [48, 120, 352], - 3: [48, 136, 384], - 4: [56, 160, 448], - 5: [64, 176, 512], - 6: [72, 200, 576], - 7: [80, 224, 640], - 8: [88, 248, 704], - } - - self.backbone = create_model( - f'efficientnet_b{compound_coef}', - pretrained=pretrained, - features_only=True, - exportable=True, - ) - - self.bifpn = BiFPNs( - in_channels_list=conv_channel_coef[compound_coef], - out_channels=self.fpn_num_filters[compound_coef], - n_bifpn=self.fpn_cell_repeats[compound_coef], - attention=True if compound_coef < 6 else False, - extra_layers=3 if compound_coef > 7 else 2, - **kwargs, - ) - - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - """ - Forward pass of the EfficientDet backbone. - - Args: - x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). - - Returns: - List[torch.Tensor]: A list of feature maps, each with shape (batch_size, channels, height, width), - where the number of feature maps is equal to the number of BiFPN layers. - """ - return self.bifpn(self.backbone(x)[2:]) diff --git a/chameleon/modules/__init__.py b/chameleon/modules/__init__.py new file mode 100644 index 0000000..5167faf --- /dev/null +++ b/chameleon/modules/__init__.py @@ -0,0 +1,2 @@ +from .backbones import build_backbone, list_backbones +from .necks import build_neck, list_necks diff --git a/chameleon/backbone/__init__.py b/chameleon/modules/backbones/__init__.py similarity index 100% rename from chameleon/backbone/__init__.py rename to chameleon/modules/backbones/__init__.py diff --git a/chameleon/backbone/gpunet.py b/chameleon/modules/backbones/gpunet.py similarity index 98% rename from chameleon/backbone/gpunet.py rename to chameleon/modules/backbones/gpunet.py index 3c9f0ff..fad57f6 100644 --- a/chameleon/backbone/gpunet.py +++ b/chameleon/modules/backbones/gpunet.py @@ -3,8 +3,7 @@ import torch import torch.nn as nn -from ..nn import PowerModule -from ..tools import has_children +from ...base import PowerModule, has_children __all__ = ['GPUNet'] diff --git a/chameleon/neck/__init__.py b/chameleon/modules/necks/__init__.py similarity index 78% rename from chameleon/neck/__init__.py rename to chameleon/modules/necks/__init__.py index 081cfe9..fe48a90 100644 --- a/chameleon/neck/__init__.py +++ b/chameleon/modules/necks/__init__.py @@ -1,18 +1,14 @@ import fnmatch from .bifpn import BiFPN, BiFPNs -from .fpn import FPN, FPNs +from .fpn import FPN NECK = { 'fpn': FPN, - 'fpns': FPNs, 'bifpn': BiFPN, 'bifpns': BiFPNs, } -__all__ = [ - 'NECK', 'BiFPN', 'BiFPNs', 'FPN', 'FPNs', 'build_neck', 'list_necks', -] def build_neck(name: str, **kwargs): if name in NECK: diff --git a/chameleon/neck/bifpn.py b/chameleon/modules/necks/bifpn.py similarity index 91% rename from chameleon/neck/bifpn.py rename to chameleon/modules/necks/bifpn.py index def0ab6..6032bdd 100644 --- a/chameleon/neck/bifpn.py +++ b/chameleon/modules/necks/bifpn.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from ..nn import (CNN2Dcell, PowerModule, SeparableConvBlock, WeightedSum, - build_activation, build_norm) +from ...base import PowerModule, build_block, build_layer __all__ = ['BiFPN', 'BiFPNs'] @@ -76,7 +75,7 @@ def __init__( useful when input feature maps have a small spatial resolution. Defaults to 'bilinear'. use_conv (bool, optional): - In BiFPN, SeparableConvBlock is used by default to replace CNN. + In BiFPN, SeparableConv2dBlock is used by default to replace CNN. If you want to use a general CNN, set use_conv to True. Defaults to False. attention (bool, optional): @@ -98,17 +97,16 @@ def __init__( if extra_layers < 0: raise ValueError('extra_layers < 0, which is not invalid.') - conv2d = CNN2Dcell if use_conv else SeparableConvBlock - # Lateral layers conv1x1s = [] for i in range(num_out_features): in_channels = in_channels_list[i] if i < num_in_features else in_channels_list[-1] if in_channels != out_channels: conv1x1s.append( - CNN2Dcell( - in_channels, - out_channels, + build_block( + 'Conv2dBlock' if use_conv else 'SeparableConv2dBlock', + in_channels=in_channels, + out_channels=out_channels, kernel=1, stride=1, padding=0, @@ -120,9 +118,10 @@ def __init__( self.conv1x1s = nn.ModuleList(conv1x1s) self.conv_up_3x3s = nn.ModuleList([ - conv2d( - out_channels, - out_channels, + build_block( + 'Conv2dBlock' if use_conv else 'SeparableConv2dBlock', + in_channels=out_channels, + out_channels=out_channels, kernel=3, stride=1, padding=1, @@ -133,9 +132,10 @@ def __init__( ]) self.conv_down_3x3s = nn.ModuleList([ - conv2d( - out_channels, - out_channels, + build_block( + 'Conv2dBlock' if use_conv else 'SeparableConv2dBlock', + in_channels=out_channels, + out_channels=out_channels, kernel=3, stride=1, padding=1, @@ -147,14 +147,14 @@ def __init__( if extra_layers > 0: self.extra_conv_downs = nn.ModuleList([ - conv2d( - in_channels_list[-1], - in_channels_list[-1], + build_block( + 'Conv2dBlock' if use_conv else 'SeparableConv2dBlock', + in_channels=in_channels_list[-1], + out_channels=in_channels_list[-1], kernel=3, stride=2, padding=1, - norm=nn.BatchNorm2d( - in_channels_list[-1]) if norm is not None else None, + norm=nn.BatchNorm2d(in_channels_list[-1]) if norm is not None else None, act=deepcopy(act), ) for _ in range(extra_layers) @@ -180,12 +180,12 @@ def __init__( # Weight self.weighted_sum_2_input = nn.ModuleList([ - WeightedSum(2, act=nn.ReLU(False), requires_grad=attention) + build_layer('WeightedSum', input_size=2, act=nn.ReLU(False), requires_grad=attention) for _ in range(num_out_features) ]) self.weighted_sum_3_input = nn.ModuleList([ - WeightedSum(3, act=nn.ReLU(False), requires_grad=attention) + build_layer('WeightedSum', input_size=3, act=nn.ReLU(False), requires_grad=attention) for _ in range(num_out_features-2) ]) @@ -347,7 +347,7 @@ def __init__( A boolean flag indicating whether to use attention mechanism. Defaults to True. use_conv (bool, optional): - In BiFPN, SeparableConvBlock is used by default to replace CNN. + In BiFPN, SeparableConv2dBlock is used by default to replace CNN. If you want to use a general CNN, set use_conv to True. Defaults to False. diff --git a/chameleon/neck/fpn.py b/chameleon/modules/necks/fpn.py similarity index 71% rename from chameleon/neck/fpn.py rename to chameleon/modules/necks/fpn.py index 37630ca..1e4cf5a 100644 --- a/chameleon/neck/fpn.py +++ b/chameleon/modules/necks/fpn.py @@ -4,9 +4,9 @@ import torch import torch.nn as nn -from ..nn import CNN2Dcell, PowerModule, SeparableConvBlock +from ...base import PowerModule, build_block -__all__ = ['FPN', 'FPNs'] +__all__ = ['FPN'] class FPN(PowerModule): @@ -70,16 +70,15 @@ def __init__( if extra_layers < 0: raise ValueError('extra_layers < 0, which is not invalid.') - conv2d = SeparableConvBlock if use_dwconv else CNN2Dcell - self.conv1x1s = [] for i in range(num_out_features): in_channels = in_channels_list[i] if i < num_in_features else in_channels_list[-1] if in_channels != out_channels: self.conv1x1s.append( - CNN2Dcell( - in_channels, - out_channels, + build_block( + "Conv2dBlock" if not use_dwconv else "SeparableConv2dBlock", + in_channels=in_channels, + out_channels=out_channels, kernel=1, stride=1, padding=0, @@ -91,9 +90,10 @@ def __init__( self.conv1x1s = nn.ModuleList(self.conv1x1s) self.smooth3x3s = nn.ModuleList([ - conv2d( - out_channels, - out_channels, + build_block( + "Conv2dBlock" if not use_dwconv else "SeparableConv2dBlock", + in_channels=out_channels, + out_channels=out_channels, kernel=3, stride=1, padding=1, @@ -105,9 +105,10 @@ def __init__( if extra_layers > 0: self.extra_conv_downs = nn.ModuleList([ - conv2d( - in_channels_list[-1], - in_channels_list[-1], + build_block( + "Conv2dBlock" if not use_dwconv else "SeparableConv2dBlock", + in_channels=in_channels_list[-1], + out_channels=in_channels_list[-1], kernel=3, stride=2, padding=1, @@ -200,61 +201,3 @@ def build_fpn( upsample_mode=upsample_mode, use_dwconv=False, ) - - -class FPNs(PowerModule): - - def __init__( - self, - in_channels_list: List[int], - out_channels: int, - n_fpn: int, - extra_layers: int = 0, - out_indices: Optional[List[int]] = None, - upsample_mode: str = 'bilinear', - use_dwconv: bool = False, - ): - """ - Constructor of the FPN module. - - Args: - - in_channels_list (List[int]): - A list of integers representing the number of channels in each - input feature map. - out_channels (int): - The number of output channels for all feature maps. - n_fpn (int): - The number of FPN blocks to be stacked. - extra_layers (int, optional): - The number of extra down-sampling layers to add. Defaults to 0. - out_indices (Optional[List[int]], optional): - A list of integers indicating the indices of the feature maps to - output. If None, all feature maps are output. Defaults to None. - use_dwconv (bool, optional): - Whether to use depth-wise convolution in each Conv2d block. - Depth-wise convolution can reduce the number of parameters and - improve computation efficiency. However, it may also degrade the - quality of feature maps due to its low capacity. - Defaults to False. - - Raises: - ValueError: If the input `cls_method` is not supported. - """ - super().__init__() - cls_method = 'build_fpn' if not use_dwconv else 'build_dwfpn' - num_out_features = len(in_channels_list) + extra_layers - self.block = nn.ModuleList([ - getattr(FPN, cls_method)( - out_channels=out_channels, - in_channels_list=in_channels_list if i == 0 else [out_channels] * num_out_features, - extra_layers=extra_layers if i == 0 else 0, - out_indices=out_indices if i == n_fpn - 1 else None, - upsample_mode=upsample_mode, - ) for i in range(n_fpn) - ]) - - def forward(self, xs: List[torch.Tensor]) -> List[torch.Tensor]: - for fpn in self.block: - xs = fpn(xs) - return xs diff --git a/chameleon/modules/necks/pafpn.py b/chameleon/modules/necks/pafpn.py new file mode 100644 index 0000000..56f7002 --- /dev/null +++ b/chameleon/modules/necks/pafpn.py @@ -0,0 +1,94 @@ +# from typing import List, Optional, Union + +# import torch.nn as nn + +# from ..nn import CNN2Dcell, SeparableConvBlock +# from .fpn import FPN + + +# class PAFPN(FPN): + +# def __init__( +# self, +# in_channels_list: List[int], +# out_channels: int, +# extra_layers: int = 0, +# out_indices: Optional[List[int]] = None, +# norm: Optional[Union[dict, nn.Module]] = None, +# act: Optional[Union[dict, nn.Module]] = None, +# upsample_mode: str = 'bilinear', +# use_dwconv: bool = False, +# ): +# """ +# Feature Pyramid Network (FPN) module. + +# Args: +# in_channels_list (List[int]): +# A list of integers representing the number of channels in each +# input feature map. +# out_channels (int): +# The number of output channels for all feature maps. +# extra_layers (int, optional): +# The number of extra down-sampling layers to add. Defaults to 0. +# out_indices (Optional[List[int]], optional): +# A list of integers indicating the indices of the feature maps to +# output. If None, all feature maps are output. Defaults to None. +# norm Optional[Union[dict, nn.Module]]: +# Optional normalization module or dictionary of its parameters. +# Defaults to None. +# act Optional[Union[dict, nn.Module]]: +# Optional activation function or dictionary of its parameters. +# Defaults to None. +# upsample_mode (str, optional): +# The type of upsampling method to use, which can be 'bilinear' or +# 'nearest'. Bilinear upsampling is recommended in most cases for +# its better performance. Nearest neighbor upsampling may be useful +# when input feature maps have a small spatial resolution. +# Defaults to 'bilinear'. +# use_dwconv (bool, optional): +# Whether to use depth-wise convolution in each Conv2d block. +# Depth-wise convolution can reduce the number of parameters and +# improve computation efficiency. However, it may also degrade the +# quality of feature maps due to its low capacity. +# Defaults to False. + +# Raises: +# ValueError: If the number of input feature maps does not match the length of `in_channels_list`. +# Or if `extra_layers` is negative. +# """ +# super().__init__( +# in_channels_list=in_channels_list, +# out_channels=out_channels, +# extra_layers=extra_layers, +# out_indices=out_indices, +# norm=norm, +# act=act, +# upsample_mode=upsample_mode, +# use_dwconv=use_dwconv, +# ) +# conv2d = SeparableConvBlock if use_dwconv else CNN2Dcell + +# self.downsample_convs = nn.ModuleList() +# self.pafpn_convs = nn.ModuleList() +# for i in range(self.start_level + 1, self.backbone_end_level): +# d_conv = conv2d( +# out_channels, +# out_channels, +# kernel=3, +# stride=2, +# padding=1, +# norm=deepcopy(norm), +# act=deepcopy(act), +# ) +# pafpn_conv = ConvModule( +# out_channels, +# out_channels, +# 3, +# padding=1, +# conv_cfg=conv_cfg, +# norm_cfg=norm_cfg, +# act_cfg=act_cfg, +# inplace=False, +# ) +# self.downsample_convs.append(d_conv) +# self.pafpn_convs.append(pafpn_conv) diff --git a/chameleon/nn/__init__.py b/chameleon/nn/__init__.py deleted file mode 100644 index 6304b80..0000000 --- a/chameleon/nn/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from torch.nn import * - -from .aspp import * -from .block import * -from .cnn import * -from .components import * -from .dwcnn import * -from .grl import * -from .mbcnn import * -from .positional_encoding import * -from .selayer import * -from .utils import * -from .vae import * - - -def build_nn_cls(name): - cls_ = globals().get(name, None) - if cls_ is None: - raise ImportError(f'name {name} is not in nn.') - return cls_ - - -def build_nn(name, **kwargs): - return build_nn_cls(name)(**kwargs) diff --git a/chameleon/nn/block.py b/chameleon/nn/block.py deleted file mode 100644 index cf959f9..0000000 --- a/chameleon/nn/block.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .components import build_activation, build_norm -from .utils import PowerModule - -__all__ = ['SeparableConvBlock'] - - -class SeparableConvBlock(PowerModule): - - def __init__( - self, - in_channels: int, - out_channels: int = None, - kernel: Union[int, Tuple[int, int]] = 3, - stride: Union[int, Tuple[int, int]] = 1, - padding: Union[int, Tuple[int, int]] = 1, - bias: Optional[bool] = None, - norm: Optional[Union[dict, nn.Module]] = None, - act: Optional[Union[dict, nn.Module]] = None, - ): - """ - A separable convolution block consisting of a depthwise convolution and a pointwise convolution. - - Args: - in_channels (int): - Number of input channels. - out_channels (int, optional): - Number of output channels. If not provided, defaults to `in_channels`. - kernel (int or Tuple[int, int], optional): - Size of the convolution kernel. Defaults to 3. - stride (int or Tuple[int, int], optional): - Stride of the convolution. Defaults to 1. - padding (int or Tuple[int, int], optional): - Padding added to all four sides of the input. Defaults to 1. - bias (bool, optional): - Whether to include a bias term in the convolutional layer. - If bias = None, bias would be set as Ture when normalization layer is None and - False when normalization layer is not None. - Defaults to None. - norm (dict or nn.Module, optional): - Configuration of normalization layer. Defaults to None. - act (dict or nn.Module, optional): - Configuration of activation layer. Defaults to None. - """ - super().__init__() - out_channels = in_channels if out_channels is None else out_channels - - if bias is None: - bias = True if norm is None else False - - self.depthwise_conv = nn.Conv2d( - in_channels, - in_channels, - kernel_size=kernel, - stride=stride, - padding=padding, - groups=in_channels, - bias=bias, - ) - - self.pointwise_conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.norm = build_norm(**norm) if isinstance(norm, dict) else norm - self.act = build_activation(**act) if isinstance(act, dict) else act - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.depthwise_conv(x) - x = self.pointwise_conv(x) - if self.norm is not None: - x = self.norm(x) - if self.act is not None: - x = self.act(x) - return x diff --git a/chameleon/nn/cnn.py b/chameleon/nn/cnn.py deleted file mode 100644 index 62ea897..0000000 --- a/chameleon/nn/cnn.py +++ /dev/null @@ -1,123 +0,0 @@ -from collections import OrderedDict -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .components import build_activation, build_dropout, build_norm, build_pool -from .utils import PowerModule - -__all__ = [ - 'CNN2Dcell', -] - - -class CNN2Dcell(PowerModule): - - def __init__( - self, - in_channels: Union[float, int], - out_channels: Union[float, int], - kernel: Union[int, Tuple[int, int]] = 3, - stride: Union[int, Tuple[int, int]] = 1, - padding: Union[int, Tuple[int, int]] = 1, - dilation: int = 1, - groups: int = 1, - bias: Optional[bool] = None, - padding_mode: str = 'zeros', - norm: Union[dict, nn.Module] = None, - dropout: Union[dict, nn.Module] = None, - act: Union[dict, nn.Module] = None, - pool: Union[dict, nn.Module] = None, - init_type: str = 'normal', - ): - """ - This class is used to build a 2D convolutional neural network cell. - - Args: - in_channels (int or float): - Number of input channels. - out_channels (int or float): - Number of output channels. - kernel (int or tuple, optional): - Size of the convolutional kernel. Defaults to 3. - stride (int or tuple, optional): - Stride size. Defaults to 1. - padding (int or tuple, optional): - Padding size. Defaults to 1. - dilation (int, optional): - Spacing between kernel elements. Defaults to 1. - groups (int, optional): - Number of blocked connections from input channels to output - channels. Defaults to 1. - bias (bool, optional): - Whether to include a bias term in the convolutional layer. - If bias = None, bias would be set as Ture when normalization layer is None and - False when normalization layer is not None. - Defaults to None. - padding_mode (str, optional): - Options = {'zeros', 'reflect', 'replicate', 'circular'}. - Defaults to 'zeros'. - norm (Union[dict, nn.Module], optional): - normalization layer or a dictionary of arguments for building a - normalization layer. Default to None. - dropout (Union[dict, nn.Module], optional): - dropout layer or a dictionary of arguments for building a dropout - layer. Default to None. - act (Union[dict, nn.Module], optional): - Activation function or a dictionary of arguments for building an - activation function. Default to None. - pool (Union[dict, nn.Module], optional): - pooling layer or a dictionary of arguments for building a pooling - layer. Default to None. - init_type (str): - Method for initializing model parameters. Default to 'normal'. - Options = {'normal', 'uniform'} - - Examples for using norm, act, and pool: - 1. cell = CNN2Dcell(in_channels=3, - out_channels=12, - norm=nn.BatchNorm2d(12), - act=nn.ReLU(), - pool=nn.AdaptiveAvgPool2d(1)) - 2. cell = CNN2Dcell(in_channels=3, - out_channels=12, - norm={'name': 'BatchNorm2d', 'num_features': 12}, - act={'name': 'ReLU', 'inplace': True}) - - Attributes: - layer (nn.ModuleDict): a dictionary of layer contained in the cell. - """ - super().__init__() - self.layer = nn.ModuleDict() - - if bias is None: - bias = True if norm is None else False - - self.layer['cnn'] = nn.Conv2d( - int(in_channels), - int(out_channels), - kernel_size=kernel, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - ) - - optional_modules = OrderedDict({ - 'norm': build_norm(**norm) if isinstance(norm, dict) else norm, - 'dp': build_dropout(**dropout) if isinstance(dropout, dict) else dropout, - 'act': build_activation(**act) if isinstance(act, dict) else act, - 'pool': build_pool(**pool) if isinstance(pool, dict) else pool, - }) - for name, m in optional_modules.items(): - if m is not None: - self.layer[name] = m - self.initialize_weights_(init_type) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for _, m in self.layer.items(): - x = m(x) - return x diff --git a/chameleon/nn/components/__init__.py b/chameleon/nn/components/__init__.py deleted file mode 100644 index bfa7884..0000000 --- a/chameleon/nn/components/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .activation import * -from .dropout import * -from .loss import * -from .norm import * -from .pooling import * diff --git a/chameleon/nn/components/dropout.py b/chameleon/nn/components/dropout.py deleted file mode 100644 index ee427a4..0000000 --- a/chameleon/nn/components/dropout.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Union - -import torch.nn as nn -from torch.nn import AlphaDropout, Dropout, Dropout2d, Dropout3d - -__all__ = [ - 'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'build_dropout', -] - - -def build_dropout(name, **options) -> Union[nn.Module, None]: - cls = globals().get(name, None) - if cls is None: - raise ValueError(f'Dropout named {name} is not support.') - return cls(**options) diff --git a/chameleon/nn/dwcnn.py b/chameleon/nn/dwcnn.py deleted file mode 100644 index 78d2bb1..0000000 --- a/chameleon/nn/dwcnn.py +++ /dev/null @@ -1,41 +0,0 @@ -from collections import OrderedDict - -import torch.nn as nn - -__all__ = ['depth_conv2d', 'conv_dw', 'conv_dw_in'] - - -def depth_conv2d(in_channels: int, out_channels: int, kernel: int = 1, stride: int = 1, pad: int = 0): - return nn.Sequential( - OrderedDict([ - ('conv3x3', nn.Conv2d(in_channels, in_channels, kernel_size=kernel, stride=stride, padding=pad, groups=in_channels),), - ('act', nn.ReLU(),), - ('conv1x1', nn.Conv2d(in_channels, out_channels, kernel_size=1)), - ]) - ) - - -def conv_dw(in_channels: int, out_channels: int, stride: int, act: nn.Module = nn.ReLU()): - return nn.Sequential( - OrderedDict([ - ('conv3x3', nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False)), - ('bn1', nn.BatchNorm2d(in_channels)), - ('act1', nn.ReLU()), - ('conv1x1', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)), - ('bn2', nn.BatchNorm2d(out_channels)), - ('act2', act), - ]) - ) - - -def conv_dw_in(in_channels: int, out_channels: int, stride: int, act: nn.Module = nn.ReLU()): - return nn.Sequential( - OrderedDict([ - ('conv3x3', nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False)), - ('in1', nn.InstanceNorm2d(in_channels)), - ('act1', nn.ReLU()), - ('conv1x1', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)), - ('in2', nn.InstanceNorm2d(out_channels)), - ('act2', act), - ]) - ) diff --git a/chameleon/nn/mbcnn.py b/chameleon/nn/mbcnn.py deleted file mode 100644 index eba63a6..0000000 --- a/chameleon/nn/mbcnn.py +++ /dev/null @@ -1,211 +0,0 @@ -from copy import deepcopy -from typing import Tuple, Union - -import torch -import torch.nn as nn - -from .cnn import CNN2Dcell -from .components import build_norm -from .selayer import SELayer -from .utils import PowerModule - -__all__ = ['MBCNNcell'] - - -class MBCNNcell(PowerModule): - - def __init__( - self, - in_channels: int, - out_channels: int, - hid_channels: int = None, - kernel: Union[int, Tuple[int, int]] = 3, - stride: Union[int, Tuple[int, int]] = 1, - use_se: bool = False, - se_reductioin: int = 4, - inner_norm: Union[dict, nn.Module] = None, - inner_act: Union[dict, nn.Module] = None, - norm: Union[dict, nn.Module] = None, - ): - """ - This neural network block is commonly known as the "inverted residual block", - which is used in MobileNetV2, MobileNetV3, and EfficientNet (but not always). - ref: https://arxiv.org/pdf/1905.02244.pdf - - For MobileNetV1, the block consists of a kxk depth-wise convolution with - group normalization, batch normalization, and ReLU activation, followed - by a 1x1 projection with batch normalization. - - mbv1: - input ---> kxk depth-wise (group, bn, relu) ---> 1x1 projection (bn) - - For MobileNetV2, the block starts with a 1x1 expansion with batch normalization - and ReLU6 activation, followed by a kxk depth-wise convolution with group - normalization, batch normalization, and ReLU6 activation, and ends with a - 1x1 projection with batch normalization. - - mbv2: - input ---> 1x1 expansion (bn, relu6) ---> kxk depth-wise (group, bn, relu6) ---> 1x1 projection (bn) - - For MobileNetV3, the block starts with a 1x1 expansion with batch normalization - and h-swish activation, followed by a kxk depth-wise convolution with group - normalization, batch normalization, and h-swish activation, and ends with a - 1x1 projection with batch normalization. In addition, MobileNetV3 uses a - squeeze-and-excitation (SE) layer to enhance feature interdependencies. - - mbv3: - input ---> 1x1 expansion (bn, hswish) ---> kxk depth-wise (group, bn, hswish) ---> 1x1 projection (bn) - | ↑ - ↓----------> SE layer (v3) -------->| - - - Args: - in_channels (int): - The number of input channels. - hid_channels (int): - The number of hidden channels for expanding dimensions. - out_channels (int): - The number of output channels. - kernel (Union[int, Tuple[int, int]], optional): - The kernel size of the depth-wise convolution. Defaults to 3. - stride (int, optional): - The stride size of the depth-wise convolution. Defaults to 1. - use_se (bool, optional): - Whether to use the SE layer. Defaults to True. - se_reduction (int, optional): - Reduction ratio for the number of hidden channels in the SE layer. - Defaults to 4. - inner_norm (Union[dict, nn.Module], optional): - Dictionary or function that creates a normalization layer inside - the MB block. Defaults to None. - inner_act (Union[dict, nn.Module], optional): - Dictionary or function that creates an activation layer inside - the MB block. Defaults to None. - norm (Union[dict, nn.Module], optional): - Dictionary or function that creates a normalization layer on the - last stage. Defaults to None. - """ - super().__init__() - self.identity = stride == 1 and in_channels == out_channels - - if hid_channels is None: - hid_channels = in_channels - - if hid_channels != in_channels: - self.expdim = CNN2Dcell( - in_channels, - hid_channels, - kernel=1, - stride=1, - padding=0, - norm=deepcopy(inner_norm), - act=deepcopy(inner_act), - ) - - padding = (kernel - 1) // 2 if isinstance(kernel, int) else \ - ((kernel[0] - 1) // 2, (kernel[1] - 1) // 2) - - self.dwise = CNN2Dcell( - hid_channels, - hid_channels, - kernel=kernel, - stride=stride, - padding=padding, - groups=hid_channels, - norm=deepcopy(inner_norm), - act=deepcopy(inner_act), - ) - - if use_se: - self.dwise_se = SELayer( - hid_channels, - se_reductioin, - ) - - self.pwise_linear = CNN2Dcell( - hid_channels, - out_channels, - kernel=1, - stride=1, - padding=0, - ) - - if norm is not None: - self.norm = norm if isinstance(norm, nn.Module) else build_norm(**norm) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = x - if hasattr(self, 'expdim'): - out = self.expdim(out) - out = self.dwise(out) - if hasattr(self, 'dwise_se'): - out = self.dwise_se(out) - out = self.pwise_linear(out) - if hasattr(self, 'norm'): - out = self.norm(out) - out = x + out if self.identity else out # skip connection - return out - - @classmethod - def build_mbv1block( - cls, - in_channels: int, - out_channels: int, - kernel: Union[int, Tuple[int, int]] = 3, - stride: Union[int, Tuple[int, int]] = 1, - ): - return cls( - in_channels=in_channels, - out_channels=out_channels, - hid_channels=in_channels, - kernel=kernel, - stride=stride, - use_se=False, - inner_norm=nn.BatchNorm2d(in_channels), - inner_act=nn.ReLU(False), - norm=nn.BatchNorm2d(out_channels), - ) - - @classmethod - def build_mbv2block( - cls, - in_channels: int, - out_channels: int, - expand_ratio: float = 2, - kernel: Union[int, Tuple[int, int]] = 3, - stride: Union[int, Tuple[int, int]] = 1, - ): - hid_channels = int(in_channels * expand_ratio) - return cls( - in_channels=in_channels, - out_channels=out_channels, - hid_channels=hid_channels, - kernel=kernel, - stride=stride, - use_se=False, - inner_norm=nn.BatchNorm2d(hid_channels), - inner_act=nn.ReLU6(False), - norm=nn.BatchNorm2d(out_channels), - ) - - @classmethod - def build_mbv3block( - cls, - in_channels: int, - out_channels: int, - expand_ratio: float = 2, - kernel: Union[int, Tuple[int, int]] = 3, - stride: Union[int, Tuple[int, int]] = 1, - ): - hid_channels = int(in_channels * expand_ratio) - return cls( - in_channels=in_channels, - out_channels=out_channels, - hid_channels=hid_channels, - kernel=kernel, - stride=stride, - use_se=True, - inner_norm=nn.BatchNorm2d(hid_channels), - inner_act=nn.Hardswish(False), - norm=nn.BatchNorm2d(out_channels), - ) diff --git a/chameleon/nn/utils.py b/chameleon/nn/utils.py deleted file mode 100644 index c7d7b70..0000000 --- a/chameleon/nn/utils.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Any, List, Optional, Union - -import torch -import torch.nn as nn - -from .components import build_activation - -__all__ = [ - 'PowerModule', 'initialize_weights', 'WeightedSum', 'Identity', - 'Transpose', 'Permute', -] - - -def initialize_weights( - model: nn.Module, - init_type: str = 'normal', - recursive: bool = True -) -> None: - """ - Initialize the weights in the given model. - - Args: - model (nn.Module): - The model to initialize. - init_type (str, optional): - The initialization method to use. Supported options are 'uniform' - and 'normal'. Defaults to 'normal'. - recursive (bool, optional): - Whether to recursively initialize child modules. Defaults to True. - - Raises: - TypeError: If init_type is not supported. - """ - if not isinstance(model, nn.Module): - raise TypeError( - f'model must be an instance of nn.Module, but got {type(model)}') - - init_functions = { - 'uniform': nn.init.kaiming_uniform_, - 'normal': nn.init.kaiming_normal_ - } - - if init_type not in init_functions: - raise TypeError(f'init_type {init_type} is not supported.') - nn_init = init_functions[init_type] - - def _recursive_init(m): - for child in m.children(): - if len(list(child.children())) > 0 and recursive: - _recursive_init(child) - else: - if isinstance(child, (nn.Conv2d, nn.Linear)): - nn_init(child.weight) - if child.bias is not None: - nn.init.zeros_(child.bias) - elif isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if child.affine: - nn.init.ones_(child.weight) - if child.bias is not None: - nn.init.zeros_(child.bias) - - _recursive_init(model) - - -class PowerModule(nn.Module): - """ - A module that provides additional functionality for weight initialization, - freezing and melting layers. - """ - - def initialize_weights_(self, init_type: str = 'normal') -> None: - """ - Initialize the weights of the module. - - Args: - init_type (str): The type of initialization. Can be 'normal' or 'uniform'. - """ - initialize_weights(self, init_type) - - def freeze(self, part_names: Union[str, List[str]] = 'all', verbose: bool = False) -> None: - """ - Freeze the parameters of specified layers. - - Args: - part_names (Union[str, List[str]]): The names of the layers to freeze. - If 'all', all layers are frozen. - verbose (bool): Whether to print messages indicating which layers were frozen. - """ - if part_names == 'all': - for name, params in self.named_parameters(): - if verbose: - print(f'Freezing layer {name}') - params.requires_grad_(False) - elif part_names is None: - return - else: - part_names = [part_names] if isinstance(part_names, str) \ - else part_names - for layer_name in part_names: - module = self - for attr in layer_name.split('.'): - module = getattr(module, attr) - for name, param in module.named_parameters(): - if verbose: - print(f'Freezing layer {layer_name}.{name}') - param.requires_grad_(False) - - def melt(self, part_names: Union[str, List[str]] = 'all', verbose: bool = False) -> None: - """ - Unfreeze the parameters of specified layers. - - Args: - part_names (Union[str, List[str]]): The names of the layers to unfreeze. - If 'all', all layers are unfrozen. - verbose (bool): Whether to print messages indicating which layers were unfrozen. - """ - if part_names == 'all': - for name, params in self.named_parameters(): - if verbose: - print(f'Unfreezing layer {name}') - params.requires_grad_(True) - elif part_names is None: - return - else: - part_names = [part_names] if isinstance(part_names, str) \ - else part_names - for layer_name in part_names: - module = self - for attr in layer_name.split('.'): - module = getattr(module, attr) - for name, param in module.named_parameters(): - if verbose: - print(f'Unfreezing layer {layer_name}.{name}') - param.requires_grad_(True) - - -class WeightedSum(nn.Module): - - def __init__( - self, - input_size: int, - act: Optional[Union[dict, nn.Module]] = None, - requires_grad: bool = True, - ) -> None: - """ - Initializes a WeightedSum module. - - Args: - input_size (int): - The number of inputs to be summed. - act Optional[Union[dict, nn.Module]]: - Optional activation function or dictionary of its parameters. - Defaults to None. - requires_grad (bool, optional): - Whether to require gradients for the weights. Defaults to True. - """ - super().__init__() - self.input_size = input_size - self.weights = nn.Parameter( - torch.ones(input_size, dtype=torch.float32), - requires_grad=requires_grad - ) - self.weights_relu = nn.ReLU() - if act is None: - self.relu = nn.Identity() - else: - self.relu = act if isinstance(act, nn.Module) \ - else build_activation(**act) - self.epsilon = 1e-4 - - def forward(self, x: List[torch.Tensor]) -> torch.Tensor: - if len(x) != self.input_size: - raise ValueError('Invalid input size not equal to weight size.') - weights = self.weights_relu(self.weights) - weights = weights / ( - torch.sum(weights, dim=0, keepdim=True) + self.epsilon) - weighted_x = torch.einsum( - 'i,i...->...', weights, torch.stack(x, dim=0)) - weighted_x = self.relu(weighted_x) - return weighted_x - - -class Identity(PowerModule): - r"""A placeholder identity operator that is argument-insensitive. - - Args: - args: any argument (unused) - kwargs: any keyword argument (unused) - - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. - - Examples:: - - >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) - >>> input = torch.randn(128, 20) - >>> output = m(input) - >>> print(output.size()) - torch.Size([128, 20]) - - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return input - - -class Transpose(nn.Module): - - def __init__(self, dim1: int, dim2: int) -> None: - super().__init__() - self.dim1 = dim1 - self.dim2 = dim2 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.transpose(self.dim1, self.dim2) - - -class Permute(nn.Module): - - def __init__(self, dims: List[int]) -> None: - super().__init__() - self.dims = dims - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.permute(*self.dims) diff --git a/chameleon/tools/__init__.py b/chameleon/tools/__init__.py index b04de5e..fdfeba1 100644 --- a/chameleon/tools/__init__.py +++ b/chameleon/tools/__init__.py @@ -1,4 +1 @@ -from .custom_aug import * -from .mixin import * -from .model_profile import * -from .replace import * +from .calflops import calculate_flops diff --git a/chameleon/tools/calflops/__init__.py b/chameleon/tools/calflops/__init__.py new file mode 100644 index 0000000..6f73fd9 --- /dev/null +++ b/chameleon/tools/calflops/__init__.py @@ -0,0 +1,18 @@ +# !usr/bin/env python +# -*- coding:utf-8 -*- + +''' + Description : + Version : 1.0 + Author : MrYXJ + Mail : yxj2017@gmail.com + Github : https://github.com/MrYxJ + Date : 2023-08-19 10:27:55 + LastEditTime : 2023-09-05 15:31:43 + Copyright (C) 2023 mryxj. All rights reserved. +''' + +from .flops_counter import calculate_flops +from .utils import (bytes_to_string, flops_to_string, + generate_transformer_input, macs_to_string, + number_to_string, params_to_string) diff --git a/chameleon/tools/calflops/calculate_pipline.py b/chameleon/tools/calflops/calculate_pipline.py new file mode 100644 index 0000000..454dba9 --- /dev/null +++ b/chameleon/tools/calflops/calculate_pipline.py @@ -0,0 +1,378 @@ +# !usr/bin/env python +# -*- coding:utf-8 -*- + +''' + Description : + Version : 1.0 + Author : MrYXJ + Mail : yxj2017@gmail.com + Github : https://github.com/MrYxJ + Date : 2023-08-20 11:04:11 + LastEditTime : 2023-09-08 23:42:00 + Copyright (C) 2023 mryxj. All rights reserved. +''' + +''' +The part of code is inspired by ptflops and deepspeed profiling. +''' + +from functools import partial + +from .pytorch_ops import (MODULE_HOOK_MAPPING, _patch_functionals, + _patch_tensor_methods, _reload_functionals, + _reload_tensor_methods) +from .utils import (flops_to_string, get_module_flops, get_module_macs, + macs_to_string, number_to_string, params_to_string) + +DEFAULT_PRECISION = 2 +module_flop_count = [] +module_mac_count = [] +old_functions = {} + + +class CalFlopsPipline(object): + """The Pipline of calculating FLOPs(number of estimated floating-point operations) and Parameters of each module in a PyTorch model. + The pipline is calculating the forward(and alson include back propagation) pass of a PyTorch model and prints the model graph with the calculated static attached to each module. + It can easily get only final resulst of FLOPs about model, and also can be showed how flops and parameters are spent in the model and which modules or layers could be the bottleneck in detailed. + """ + + def __init__(self, model, include_backPropagation, compute_bp_factor, is_sparse): + """Init Pipline of Calculating the FLOPs about model. + + Args: + model (pytorch model): The model must be a pytorh model now. + compute_fwd_factor (float): Defaults to 2.0. According to https://epochai.org/blog/backward-forward-FLOP-ratio + """ + + self.model = model + self.include_backPropagation = include_backPropagation # Whether the calculation results include model backpropagation + self.compute_bp_factor = compute_bp_factor # Backpropagation takes twice as much computation as forward propagation + self.pipline_started = False # The flag of calculating FLOPs pipline started + self.func_patched = False # The flag of wheather calculating functional are patched + self.is_sparse = is_sparse # Whether to exclude sparse matrix flops + + def start_flops_calculate(self, ignore_list=None): + """Starts the pipline of calculating FLOPs. + + Extra attributes are added recursively to all the modules and the calculate torch.nn.functionals are monkey patched. + + Args: + ignore_list (list, optional): the list of modules to ignore while Piplining. Defaults to None. + """ + + self.reset_flops_calculate() + _patch_functionals(old_functions, module_flop_count, module_mac_count) + _patch_tensor_methods(old_functions, module_flop_count, module_mac_count) + + def register_module_hooks(module, ignore_list): + if ignore_list and type(module) in ignore_list: + return + + # if computing the flops of a module directly + if type(module) in MODULE_HOOK_MAPPING: + if not hasattr(module, "__flops_handle__"): + module.__flops_handle__ = module.register_forward_hook(MODULE_HOOK_MAPPING[type(module)]) + return + + # if computing the flops of the functionals in a module + def pre_hook(module, input): + module_flop_count.append([]) + module_mac_count.append([]) + + if not hasattr(module, "__pre_hook_handle__"): + module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook) + + def post_hook(module, input, output): + if module_flop_count: + module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]]) + module_flop_count.pop() + module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]]) + module_mac_count.pop() + + if not hasattr(module, "__post_hook_handle__"): + module.__post_hook_handle__ = module.register_forward_hook(post_hook) + + self.model.apply(partial(register_module_hooks, ignore_list=ignore_list)) + self.pipline_started = True + self.func_patched = True + + def stop_flops_calculate(self): + """Stop the pipline of calculating FLOPs. + + All torch.nn.functionals are restored to their originals. + """ + if self.pipline_started and self.func_patched: + _reload_functionals(old_functions) + _reload_tensor_methods(old_functions) + self.func_patched = False + + def remove_calculate_attrs(module): + if hasattr(module, "__pre_hook_handle__"): + module.__pre_hook_handle__.remove() + del module.__pre_hook_handle__ + if hasattr(module, "__post_hook_handle__"): + module.__post_hook_handle__.remove() + del module.__post_hook_handle__ + if hasattr(module, "__flops_handle__"): + module.__flops_handle__.remove() + del module.__flops_handle__ + + self.model.apply(remove_calculate_attrs) + + def reset_flops_calculate(self): + """Resets the pipline of calculating FLOPs. + + Adds or resets the extra attributes, include flops、macs、params. + """ + + def add_or_reset_attrs(module): + module.__flops__ = 0 + module.__macs__ = 0 + module.__params__ = sum( + p.count_nonzero().item() for p in module.parameters() if p.requires_grad + ) if self.is_sparse else sum( + p.numel() for p in module.parameters() if p.requires_grad) + # just calculate parameter need training. + + self.model.apply(add_or_reset_attrs) + + def end_flops_calculate(self): + """Ends the pipline of calculating FLOPs. + + The added attributes and handles are removed recursively on all the modules. + """ + if not self.pipline_started: + return + self.stop_flops_calculate() + self.pipline_started = False + + def remove_calculate_attrs(module): + if hasattr(module, "__flops__"): + del module.__flops__ + if hasattr(module, "__macs__"): + del module.__macs__ + if hasattr(module, "__params__"): + del module.__params__ + + self.model.apply(remove_calculate_attrs) + + def get_total_flops(self, as_string=False): + """Returns the total flops of the model. + + Args: + as_string (bool, optional): whether to output the flops as string. Defaults to False. + + Returns: + The number of multiply-accumulate operations of the model forward pass. + """ + total_flops = get_module_flops(self.model, is_sparse=self.is_sparse) + return number_to_string(total_flops) if as_string else total_flops + + def get_total_macs(self, as_string=False): + """Returns the total MACs of the model. + + Args: + as_string (bool, optional): whether to output the flops as string. Defaults to False. + + Returns: + The number of multiply-accumulate operations of the model forward pass. + """ + total_macs = get_module_macs(self.model, is_sparse=self.is_sparse) + return macs_to_string(total_macs) if as_string else total_macs + + def get_total_params(self, as_string=False): + """Returns the total number of parameters stored per rank. + + Args: + as_string (bool, optional): whether to output the parameters as string. Defaults to False. + is_sparse (bool, optional): whether to output the parameters as string. Defaults to False. + + Returns: + The total number of parameters stored per rank. + """ + total_params = self.model.__params__ + return params_to_string(total_params) if as_string else total_params + + def print_return_model_pipline(self, units=None, precision=DEFAULT_PRECISION, print_detailed=True, + print_results=True): + """Prints the model graph with the calculateing pipline attached to each module. + + Args: + module_depth (int, optional): The depth of the model to which to print the aggregated module information. When set to -1, it prints information from the top to the innermost modules (the maximum depth). + top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified. + print_detailed (bool, optional): Whether to print the detailed model profile. + """ + if not self.pipline_started: + return + + total_flops = self.get_total_flops() + total_macs = self.get_total_macs() + total_params = self.get_total_params() + + self.flops = total_flops + self.macs = total_macs + self.params = total_params + + prints = [] + prints.append( + "\n------------------------------------- Calculate Flops Results -------------------------------------") + + prints.append("Notations:\n" + + "number of parameters (Params), number of multiply-accumulate operations(MACs),\n" + + "number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),\n" + + "fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),\n" + + "default model backpropagation takes %.2f times as much computation as forward propagation.\n" % self.compute_bp_factor) + + line_fmt = '{:<70} {:<8}' + prints.append(line_fmt.format('Total Training Params: ', params_to_string(total_params))) + + prints.append(line_fmt.format('fwd MACs: ', macs_to_string(total_macs, units=units, + precision=precision))) + prints.append(line_fmt.format('fwd FLOPs: ', flops_to_string(total_flops, units=units, + precision=precision))) + prints.append(line_fmt.format('fwd+bwd MACs: ', macs_to_string(total_macs * (1 + self.compute_bp_factor), + units=units, precision=precision))) + prints.append(line_fmt.format('fwd+bwd FLOPs: ', flops_to_string(total_flops * (1 + self.compute_bp_factor), + units=units, precision=precision))) + + def flops_repr(module): + params = module.__params__ + flops = get_module_flops(module) + macs = get_module_macs(module) + items = [ + "{} = {:g}% Params".format( + params_to_string(params), + round(100 * params / total_params, precision) if total_params else 0), + "{} = {:g}% MACs".format(macs_to_string(macs), + round(100 * macs / total_macs, precision) if total_macs else 0), + "{} = {:g}% FLOPs".format(flops_to_string(flops), + round(100 * flops / total_flops, precision) if total_flops else 0), + ] + original_extra_repr = module.original_extra_repr() + if original_extra_repr: + items.append(original_extra_repr) + return ", ".join(items) + + def add_extra_repr(module): + flops_extra_repr = flops_repr.__get__(module) + if module.extra_repr != flops_extra_repr: + module.original_extra_repr = module.extra_repr + module.extra_repr = flops_extra_repr + assert module.extra_repr != module.original_extra_repr + + def del_extra_repr(module): + if hasattr(module, "original_extra_repr"): + module.extra_repr = module.original_extra_repr + del module.original_extra_repr + + self.model.apply(add_extra_repr) + + if print_detailed: + prints.append( + "\n-------------------------------- Detailed Calculated FLOPs Results --------------------------------") + prints.append( + "Each module caculated is listed after its name in the following order: \nparams, percentage of total params, MACs, percentage of total MACs, FLOPS, percentage of total FLOPs" + ) + prints.append( + "\nNote: 1. A module can have torch.nn.module or torch.nn.functional to compute logits (e.g. CrossEntropyLoss). \n They are not counted as submodules in calflops and not to be printed out. However they make up the difference between a parent's MACs and the sum of its submodules'.\n2. Number of floating-point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n" + ) + prints.append(str(self.model)) + + self.model.apply(del_extra_repr) + + prints.append( + "---------------------------------------------------------------------------------------------------") + + return_print = "" + for line in prints: + if print_results: + print(line) + return_print += line + "\n" + return return_print + + def print_model_pipline(self, units=None, precision=DEFAULT_PRECISION, print_detailed=True): + """Prints the model graph with the calculateing pipline attached to each module. + + Args: + module_depth (int, optional): The depth of the model to which to print the aggregated module information. When set to -1, it prints information from the top to the innermost modules (the maximum depth). + top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified. + print_detailed (bool, optional): Whether to print the detailed model profile. + """ + if not self.pipline_started: + return + + total_flops = self.get_total_flops() + total_macs = self.get_total_macs() + total_params = self.get_total_params() + + self.flops = total_flops + self.macs = total_macs + self.params = total_params + + print("\n------------------------------------- Calculate Flops Results -------------------------------------") + + print("Notations:\n" + "number of parameters (Params), number of multiply-accumulate operations(MACs),\n" + "number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),\n" + "fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),\n" + "default model backpropagation takes %.2f times as much computation as forward propagation.\n" % self.compute_bp_factor) + + line_fmt = '{:<70} {:<8}' + + print(line_fmt.format('Total Training Params: ', params_to_string(total_params))) + + print(line_fmt.format('fwd MACs: ', macs_to_string(total_macs, units=units, + precision=precision))) + print(line_fmt.format('fwd FLOPs: ', flops_to_string(total_flops, units=units, + precision=precision))) + print(line_fmt.format('fwd+bwd MACs: ', macs_to_string(total_macs * (1 + self.compute_bp_factor), + units=units, precision=precision))) + print(line_fmt.format('fwd+bwd FLOPs: ', flops_to_string(total_flops * (1 + self.compute_bp_factor), + units=units, precision=precision))) + + def flops_repr(module): + params = module.__params__ + flops = get_module_flops(module) + macs = get_module_macs(module) + items = [ + "{} = {:g}% Params".format( + params_to_string(params), + round(100 * params / total_params, precision) if total_params else 0), + "{} = {:g}% MACs".format(macs_to_string(macs), + round(100 * macs / total_macs, precision) if total_macs else 0), + "{} = {:g}% FLOPs".format(flops_to_string(flops), + round(100 * macs / total_flops, precision) if total_flops else 0), + ] + original_extra_repr = module.original_extra_repr() + if original_extra_repr: + items.append(original_extra_repr) + return ", ".join(items) + + def add_extra_repr(module): + flops_extra_repr = flops_repr.__get__(module) + if module.extra_repr != flops_extra_repr: + module.original_extra_repr = module.extra_repr + module.extra_repr = flops_extra_repr + assert module.extra_repr != module.original_extra_repr + + def del_extra_repr(module): + if hasattr(module, "original_extra_repr"): + module.extra_repr = module.original_extra_repr + del module.original_extra_repr + + self.model.apply(add_extra_repr) + + if print_detailed: + print( + "\n-------------------------------- Detailed Calculated FLOPs Results --------------------------------") + print( + "Each module caculated is listed after its name in the following order: \nparams, percentage of total params, MACs, percentage of total MACs, FLOPS, percentage of total FLOPs" + ) + print( + "\nNote: 1. A module can have torch.nn.module or torch.nn.functional to compute logits (e.g. CrossEntropyLoss). \n They are not counted as submodules in calflops and not to be printed out. However they make up the difference between a parent's MACs and the sum of its submodules'.\n2. Number of floating-point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n" + ) + print(self.model) + + self.model.apply(del_extra_repr) + + print("---------------------------------------------------------------------------------------------------") diff --git a/chameleon/tools/calflops/flops_counter.py b/chameleon/tools/calflops/flops_counter.py new file mode 100644 index 0000000..23a29f6 --- /dev/null +++ b/chameleon/tools/calflops/flops_counter.py @@ -0,0 +1,189 @@ +# !usr/bin/env python +# -*- coding:utf-8 -*- + +''' + Description : + Version : 1.0 + Author : MrYXJ + Mail : yxj2017@gmail.com + Github : https://github.com/MrYxJ + Date : 2023-08-19 10:28:55 + LastEditTime : 2023-09-07 23:39:17 + Copyright (C) 2023 mryxj. All rights reserved. +''' + +import torch +import torch.nn as nn + +from .calculate_pipline import CalFlopsPipline +from .utils import (flops_to_string, generate_transformer_input, + macs_to_string, params_to_string) + + +def calculate_flops(model, + input_shape=None, + transformer_tokenizer=None, + args=[], + kwargs={}, + forward_mode="forward", + include_backPropagation=False, + compute_bp_factor=2.0, + print_results=True, + print_detailed=True, + output_as_string=True, + output_precision=2, + output_unit=None, + ignore_modules=None, + is_sparse=False): + """Returns the total floating-point operations, MACs, and parameters of a model. + + Args: + model ([torch.nn.Module]): The model of input must be a PyTorch model. + input_shape (tuple, optional): Input shape to the model. If args and kwargs is empty, the model takes a tensor with this shape as the only positional argument. Default to []. + transformers_tokenizer (None, optional): Transforemrs Toekenizer must be special if model type is transformers and args、kwargs is empty. Default to None + args (list, optional): list of positional arguments to the model, such as bert input args is [input_ids, token_type_ids, attention_mask]. Default to [] + kwargs (dict, optional): dictionary of keyword arguments to the model, such as bert input kwargs is {'input_ids': ..., 'token_type_ids':..., 'attention_mask':...}. Default to {} + forward_mode (str, optional): To determine the mode of model inference, Default to 'forward'. And use 'generate' if model inference uses model.generate(). + include_backPropagation (bool, optional): Decides whether the final return FLOPs computation includes the computation for backpropagation. + compute_bp_factor (float, optional): The model backpropagation is a multiple of the forward propagation computation. Default to 2. + print_results (bool, optional): Whether to print the model profile. Defaults to True. + print_detailed (bool, optional): Whether to print the detailed model profile. Defaults to True. + output_as_string (bool, optional): Whether to print the output as string. Defaults to True. + output_precision (int, optional) : Output holds the number of decimal places if output_as_string is True. Default to 2. + output_unit (str, optional): The unit used to output the result value, such as T, G, M, and K. Default is None, that is the unit of the output decide on value. + ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None. + is_sparse (bool, optional): Whether to exclude sparse matrix flops. Defaults to False. + + Example: + .. code-block:: python + from calflops import calculate_flops + + # Deep Learning Model, such as alexnet. + from torchvision import models + + model = models.alexnet() + batch_size = 1 + flops, macs, params = calculate_flops(model=model, + input_shape=(batch_size, 3, 224, 224), + output_as_string=True, + output_precision=4) + print("Alexnet FLOPs:%s MACs:%s Params:%s \n" %(flops, macs, params)) + #Alexnet FLOPs:1.4297 GFLOPS MACs:714.188 MMACs Params:61.1008 M + + # Transformers Model, such as bert. + from transformers import AutoModel + from transformers import AutoTokenizer + batch_size = 1 + max_seq_length = 128 + model_name = "hfl/chinese-roberta-wwm-ext/" + model_save = "../pretrain_models/" + model_name + model = AutoModel.from_pretrained(model_save) + tokenizer = AutoTokenizer.from_pretrained(model_save) + flops, macs, params = calculate_flops(model=model, + input_shape=(batch_size, max_seq_length), + transformer_tokenizer=tokenizer) + print("Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n" %(flops, macs, params)) + #Bert(hfl/chinese-roberta-wwm-ext) FLOPs:22.36 GFLOPS MACs:11.17 GMACs Params:102.27 M + + # Large Languase Model, such as llama2-7b. + from transformers import LlamaTokenizer + from transformers import LlamaForCausalLM + batch_size = 1 + max_seq_length = 128 + model_name = "llama2_hf_7B" + model_save = "../model/" + model_name + model = LlamaForCausalLM.from_pretrained(model_save) + tokenizer = LlamaTokenizer.from_pretrained(model_save) + flops, macs, params = calculate_flops(model=model, + input_shape=(batch_size, max_seq_length), + transformer_tokenizer=tokenizer) + print("Llama2(7B) FLOPs:%s MACs:%s Params:%s \n" %(flops, macs, params)) + #Llama2(7B) FLOPs:1.7 TFLOPS MACs:850.00 GMACs Params:6.74 B + + Returns: + The number of floating-point operations, multiply-accumulate operations (MACs), and parameters in the model. + """ + + assert isinstance(model, nn.Module), "model must be a PyTorch module" + # assert transformers_tokenizer and auto_generate_transformers_input and "transformers" in str(type(model)), "The model must be a transformers model if args of auto_generate_transformers_input is True and transformers_tokenizer is not None" + model.eval() + + is_transformer = True if "transformers" in str(type(model)) else False + + calculate_flops_pipline = CalFlopsPipline(model=model, + include_backPropagation=include_backPropagation, + compute_bp_factor=compute_bp_factor, + is_sparse=is_sparse) + calculate_flops_pipline.start_flops_calculate(ignore_list=ignore_modules) + + device = next(model.parameters()).device + model = model.to(device) + + if input_shape is not None: + assert len(args) == 0 and len( + kwargs) == 0, "args and kwargs must be empty value if input_shape is not None, then will be generate random input by inpust_shape" + assert type(input_shape) is tuple, "input_shape must be a tuple" + assert len(input_shape) >= 1, "input_shape must have at least one element" + + if transformer_tokenizer is None: # model is not transformers model + assert is_transformer is False, "the model is must not transformer model if input_shape is not None and transformer_tokenizer is None" + try: + input = torch.ones(()).new_empty( + (*input_shape,), + dtype=next(model.parameters()).dtype, + device=device, + ) + except StopIteration: + input = torch.ones(()).new_empty((*input_shape,)) + args = [input] + else: + assert len( + input_shape) == 2, "the format of input_shape must be (batch_size, seq_len) if model is transformers model and auto_generate_transformers_input if True" + kwargs = generate_transformer_input(input_shape=input_shape, + model_tokenizer=transformer_tokenizer, + device=device) + else: + assert transformer_tokenizer or (len(args) > 0 or len( + kwargs) > 0), "input_shape or args or kwargs one of there parameters must specified if auto_generate_input is False" + if transformer_tokenizer: + kwargs = generate_transformer_input(input_shape=None, + model_tokenizer=transformer_tokenizer, + device=device) + + if kwargs: + for key, value in kwargs.items(): + if torch.is_tensor(value): + kwargs[key] = value.to(device) + else: + kwargs = {} + for index in range(len(args)): + args[index] = args[index].to(device) + + if forward_mode == 'forward': + _ = model(*args, **kwargs) + elif forward_mode == 'generate': + _ = model.generate(*args, **kwargs) + else: + raise NotImplementedError("forward_mode should be either forward or generate") + + flops = calculate_flops_pipline.get_total_flops() + macs = calculate_flops_pipline.get_total_macs() + params = calculate_flops_pipline.get_total_params() + + if print_results: + calculate_flops_pipline.print_model_pipline(units=output_unit, + precision=output_precision, + print_detailed=print_detailed) + + calculate_flops_pipline.end_flops_calculate() + + if include_backPropagation: + flops = flops * (1 + compute_bp_factor) + macs = macs * (1 + compute_bp_factor) + + if output_as_string: + return flops_to_string(flops, units=output_unit, precision=output_precision), \ + macs_to_string(macs, units=output_unit, precision=output_precision), \ + params_to_string(params, units=output_unit, precision=output_precision) + + return flops, macs, params diff --git a/chameleon/tools/calflops/pytorch_ops.py b/chameleon/tools/calflops/pytorch_ops.py new file mode 100644 index 0000000..c337875 --- /dev/null +++ b/chameleon/tools/calflops/pytorch_ops.py @@ -0,0 +1,598 @@ +# !usr/bin/env python +# -*- coding:utf-8 -*- + +''' + Description : + Version : 1.0 + Author : MrYXJ + Mail : yxj2017@gmail.com + Github : https://github.com/MrYxJ + Date : 2023-08-19 22:34:47 + LastEditTime : 2023-08-23 11:17:33 + Copyright (C) 2023 mryxj. All rights reserved. +''' + +''' +The part of code is inspired by ptflops and deepspeed profiling. +''' + + +from collections import OrderedDict +from typing import List, Optional +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +Tensor = torch.Tensor + + +def _prod(dims): + p = 1 + for v in dims: + p *= v + return p + + +def _linear_flops_compute(input, weight, bias=None): + out_features = weight.shape[0] + macs = input.numel() * out_features + return 2 * macs, macs + +# Activation just calculate FLOPs, MACs is 0 + + +def _relu_flops_compute(input, inplace=False): + return input.numel(), 0 + + +def _prelu_flops_compute(input: Tensor, weight: Tensor): + return input.numel(), 0 + + +def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False): + return input.numel(), 0 + + +def _leaky_relu_flops_compute(input: Tensor, negative_slope: float = 0.01, inplace: bool = False): + return input.numel(), 0 + + +def _relu6_flops_compute(input: Tensor, inplace: bool = False): + return input.numel(), 0 + + +def _silu_flops_compute(input: Tensor, inplace: bool = False): + return input.numel(), 0 + + +def _gelu_flops_compute(input, **kwargs): + return input.numel(), 0 + + +def _pool_flops_compute(input, + kernel_size, + stride=None, + padding=0, + dilation=None, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + return_indices=None): + return input.numel(), 0 + + +def _conv_flops_compute(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert weight.shape[1] * groups == input.shape[1] + + batch_size = input.shape[0] + in_channels = input.shape[1] + out_channels = weight.shape[0] + kernel_dims = list(weight.shape[2:]) + input_dims = list(input.shape[2:]) + + length = len(input_dims) + + strides = stride if type(stride) is tuple else (stride, ) * length + dilations = dilation if type(dilation) is tuple else (dilation, ) * length + if isinstance(padding, str): + if padding == 'valid': + paddings = (0, ) * length + elif padding == 'same': + paddings = () + for d, k in zip(dilations, kernel_dims): + total_padding = d * (k - 1) + paddings += (total_padding // 2, ) + elif isinstance(padding, tuple): + paddings = padding + else: + paddings = (padding, ) * length + + output_dims = [] + for idx, input_dim in enumerate(input_dims): + output_dim = (input_dim + 2 * paddings[idx] - (dilations[idx] * + (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1 + output_dims.append(output_dim) + + filters_per_channel = out_channels // groups + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel + active_elements_count = batch_size * int(_prod(output_dims)) + overall_conv_macs = conv_per_position_macs * active_elements_count + overall_conv_flops = 2 * overall_conv_macs + + bias_flops = 0 + if bias is not None: + bias_flops = out_channels * active_elements_count + + return int(overall_conv_flops + bias_flops), int(overall_conv_macs) + + +def _conv_trans_flops_compute( + input, + weight, + bias=None, + stride=1, + padding=0, + output_padding=0, + groups=1, + dilation=1, +): + batch_size = input.shape[0] + in_channels = input.shape[1] + out_channels = weight.shape[1] + kernel_dims = list(weight.shape[2:]) + input_dims = list(input.shape[2:]) + + length = len(input_dims) + + paddings = padding if type(padding) is tuple else (padding, ) * length + strides = stride if type(stride) is tuple else (stride, ) * length + dilations = dilation if type(dilation) is tuple else (dilation, ) * length + + output_dims = [] + for idx, input_dim in enumerate(input_dims): + + output_dim = (input_dim + 2 * paddings[idx] - (dilations[idx] * + (kernel_dims[idx] - 1) + 1)) // strides[idx] + 1 + output_dims.append(output_dim) + + paddings = padding if type(padding) is tuple else (padding, padding) + strides = stride if type(stride) is tuple else (stride, stride) + dilations = dilation if type(dilation) is tuple else (dilation, dilation) + + filters_per_channel = out_channels // groups + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel + active_elements_count = batch_size * int(_prod(input_dims)) + overall_conv_macs = conv_per_position_macs * active_elements_count + overall_conv_flops = 2 * overall_conv_macs + + bias_flops = 0 + if bias is not None: + bias_flops = out_channels * batch_size * int(_prod(output_dims)) + + return int(overall_conv_flops + bias_flops), int(overall_conv_macs) + + +def _batch_norm_flops_compute( + input, + running_mean, + running_var, + weight=None, + bias=None, + training=False, + momentum=0.1, + eps=1e-05, +): + has_affine = weight is not None + if training: + # estimation + return input.numel() * (5 if has_affine else 4), 0 + flops = input.numel() * (2 if has_affine else 1) + return flops, 0 + + +def _layer_norm_flops_compute( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +): + has_affine = weight is not None + # estimation + return input.numel() * (5 if has_affine else 4), 0 + + +def _group_norm_flops_compute(input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5): + has_affine = weight is not None + # estimation + return input.numel() * (5 if has_affine else 4), 0 + + +def _instance_norm_flops_compute( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +): + has_affine = weight is not None + # estimation + return input.numel() * (5 if has_affine else 4), 0 + + +def _upsample_flops_compute(*args, **kwargs): + input = args[0] + size = kwargs.get('size', None) + if size is None and len(args) > 1: + size = args[1] + + if size is not None: + if isinstance(size, tuple) or isinstance(size, list): + return int(_prod(size)), 0 + else: + return int(size), 0 + + scale_factor = kwargs.get('scale_factor', None) + if scale_factor is None and len(args) > 2: + scale_factor = args[2] + assert scale_factor is not None, "either size or scale_factor should be defined" + + flops = input.numel() + if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): + flops * int(_prod(scale_factor)) + else: + flops * scale_factor**len(input) + return flops, 0 + + +def _softmax_flops_compute(input, dim=None, _stacklevel=3, dtype=None): + return input.numel(), 0 + + +def _embedding_flops_compute( + input, + weight, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, +): + return 0, 0 + + +def _dropout_flops_compute(input, p=0.5, training=True, inplace=False): + return 0, 0 + + +def _matmul_flops_compute(input, other, *, out=None): + """ + Count flops for the matmul operation. + """ + macs = _prod(input.shape) * other.shape[-1] + return 2 * macs, macs + + +def _addmm_flops_compute(input, mat1, mat2, *, beta=1, alpha=1, out=None): + """ + Count flops for the addmm operation. + """ + macs = _prod(mat1.shape) * mat2.shape[-1] + return 2 * macs + _prod(input.shape), macs + + +def _einsum_flops_compute(equation, *operands): + """ + Count flops for the einsum operation. + """ + equation = equation.replace(" ", "") + input_shapes = [o.shape for o in operands] + + # Re-map equation so that same equation with different alphabet + # representations will look the same. + letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() + mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} + equation = equation.translate(mapping) + + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + flop = int(float(line.split(":")[-1])) + return flop, 0 + raise NotImplementedError("Unsupported einsum operation.") + + +def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None): + """ + Count flops for the tensor addmm operation. + """ + macs = _prod(mat1.shape) * mat2.shape[-1] + return 2 * macs + _prod(self.shape), macs + + +def _mul_flops_compute(input, other, *, out=None): + return _elementwise_flops_compute(input, other) + + +def _add_flops_compute(input, other, *, alpha=1, out=None): + return _elementwise_flops_compute(input, other) + + +def _elementwise_flops_compute(input, other): + if not torch.is_tensor(input): + if torch.is_tensor(other): + return _prod(other.shape), 0 + else: + return 1, 0 + elif not torch.is_tensor(other): + return _prod(input.shape), 0 + else: + dim_input = len(input.shape) + dim_other = len(other.shape) + max_dim = max(dim_input, dim_other) + + final_shape = [] + for i in range(max_dim): + in_i = input.shape[i] if i < dim_input else 1 + ot_i = other.shape[i] if i < dim_other else 1 + if in_i > ot_i: + final_shape.append(in_i) + else: + final_shape.append(ot_i) + flops = _prod(final_shape) + return flops, 0 + + +def wrapFunc(func, funcFlopCompute, old_functions, module_flop_count, module_mac_count): + oldFunc = func + name = func.__str__ + old_functions[name] = oldFunc + + def newFunc(*args, **kwds): + flops, macs = funcFlopCompute(*args, **kwds) + if module_flop_count: + module_flop_count[-1].append((name, flops)) + if module_mac_count and macs: + module_mac_count[-1].append((name, macs)) + return oldFunc(*args, **kwds) + + newFunc.__str__ = func.__str__ + + return newFunc + + +def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): + gates_size = w_ih.shape[0] + # matrix matrix mult ih state and internal state + flops += 2 * w_ih.shape[0] * w_ih.shape[1] - gates_size + # matrix matrix mult hh state and internal state + flops += 2 * w_hh.shape[0] * w_hh.shape[1] - gates_size + if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): + # add both operations + flops += rnn_module.hidden_size + elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)): + # hadamard of r + flops += rnn_module.hidden_size + # adding operations from both states + flops += rnn_module.hidden_size * 3 + # last two hadamard _product and add + flops += rnn_module.hidden_size * 3 + elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)): + # adding operations from both states + flops += rnn_module.hidden_size * 4 + # two hadamard _product and add for C state + flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + # final hadamard + flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + return flops + + +def _rnn_forward_hook(rnn_module, input, output): + flops = 0 + # input is a tuple containing a sequence to process and (optionally) hidden state + inp = input[0] + batch_size = inp.shape[0] + seq_length = inp.shape[1] + num_layers = rnn_module.num_layers + + for i in range(num_layers): + w_ih = rnn_module.__getattr__("weight_ih_l" + str(i)) + w_hh = rnn_module.__getattr__("weight_hh_l" + str(i)) + if i == 0: + input_size = rnn_module.input_size + else: + input_size = rnn_module.hidden_size + flops = _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) + if rnn_module.bias: + b_ih = rnn_module.__getattr__("bias_ih_l" + str(i)) + b_hh = rnn_module.__getattr__("bias_hh_l" + str(i)) + flops += b_ih.shape[0] + b_hh.shape[0] + + flops *= batch_size + flops *= seq_length + if rnn_module.bidirectional: + flops *= 2 + rnn_module.__flops__ += int(flops) + + +def _rnn_cell_forward_hook(rnn_cell_module, input, output): + flops = 0 + inp = input[0] + batch_size = inp.shape[0] + w_ih = rnn_cell_module.__getattr__("weight_ih") + w_hh = rnn_cell_module.__getattr__("weight_hh") + input_size = inp.shape[1] + flops = _rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) + if rnn_cell_module.bias: + b_ih = rnn_cell_module.__getattr__("bias_ih") + b_hh = rnn_cell_module.__getattr__("bias_hh") + flops += b_ih.shape[0] + b_hh.shape[0] + + flops *= batch_size + rnn_cell_module.__flops__ += int(flops) + + +MODULE_HOOK_MAPPING = { + # RNN + nn.RNN: _rnn_forward_hook, + nn.GRU: _rnn_forward_hook, + nn.LSTM: _rnn_forward_hook, + nn.RNNCell: _rnn_cell_forward_hook, + nn.LSTMCell: _rnn_cell_forward_hook, + nn.GRUCell: _rnn_cell_forward_hook, +} + + +def _patch_functionals(old_functions, module_flop_count, module_mac_count): + # FC + F.linear = wrapFunc(F.linear, _linear_flops_compute, old_functions, module_flop_count, module_mac_count) + # convolutions + F.conv1d = wrapFunc(F.conv1d, _conv_flops_compute, old_functions, module_flop_count, module_mac_count) + F.conv2d = wrapFunc(F.conv2d, _conv_flops_compute, old_functions, module_flop_count, module_mac_count) + F.conv3d = wrapFunc(F.conv3d, _conv_flops_compute, old_functions, module_flop_count, module_mac_count) + + # conv transposed + F.conv_transpose1d = wrapFunc(F.conv_transpose1d, _conv_trans_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.conv_transpose2d = wrapFunc(F.conv_transpose2d, _conv_trans_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.conv_transpose3d = wrapFunc(F.conv_transpose3d, _conv_trans_flops_compute, + old_functions, module_flop_count, module_mac_count) + + # activations + F.relu = wrapFunc(F.relu, _relu_flops_compute, old_functions, module_flop_count, module_mac_count) + F.prelu = wrapFunc(F.prelu, _prelu_flops_compute, old_functions, module_flop_count, module_mac_count) + F.elu = wrapFunc(F.elu, _elu_flops_compute, old_functions, module_flop_count, module_mac_count) + F.leaky_relu = wrapFunc(F.leaky_relu, _leaky_relu_flops_compute, old_functions, module_flop_count, module_mac_count) + F.relu6 = wrapFunc(F.relu6, _relu6_flops_compute, old_functions, module_flop_count, module_mac_count) + if hasattr(F, "silu"): + F.silu = wrapFunc(F.silu, _silu_flops_compute, old_functions, module_flop_count, module_mac_count) + F.gelu = wrapFunc(F.gelu, _gelu_flops_compute, old_functions, module_flop_count, module_mac_count) + + # Normalizations + F.batch_norm = wrapFunc(F.batch_norm, _batch_norm_flops_compute, old_functions, module_flop_count, module_mac_count) + F.layer_norm = wrapFunc(F.layer_norm, _layer_norm_flops_compute, old_functions, module_flop_count, module_mac_count) + F.instance_norm = wrapFunc(F.instance_norm, _instance_norm_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.group_norm = wrapFunc(F.group_norm, _group_norm_flops_compute, old_functions, module_flop_count, module_mac_count) + + # poolings + F.avg_pool1d = wrapFunc(F.avg_pool1d, _pool_flops_compute, old_functions, module_flop_count, module_mac_count) + F.avg_pool2d = wrapFunc(F.avg_pool2d, _pool_flops_compute, old_functions, module_flop_count, module_mac_count) + F.avg_pool3d = wrapFunc(F.avg_pool3d, _pool_flops_compute, old_functions, module_flop_count, module_mac_count) + F.max_pool1d = wrapFunc(F.max_pool1d, _pool_flops_compute, old_functions, module_flop_count, module_mac_count) + F.max_pool2d = wrapFunc(F.max_pool2d, _pool_flops_compute, old_functions, module_flop_count, module_mac_count) + F.max_pool3d = wrapFunc(F.max_pool3d, _pool_flops_compute, old_functions, module_flop_count, module_mac_count) + F.adaptive_avg_pool1d = wrapFunc(F.adaptive_avg_pool1d, _pool_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.adaptive_avg_pool2d = wrapFunc(F.adaptive_avg_pool2d, _pool_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.adaptive_avg_pool3d = wrapFunc(F.adaptive_avg_pool3d, _pool_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.adaptive_max_pool1d = wrapFunc(F.adaptive_max_pool1d, _pool_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.adaptive_max_pool2d = wrapFunc(F.adaptive_max_pool2d, _pool_flops_compute, + old_functions, module_flop_count, module_mac_count) + F.adaptive_max_pool3d = wrapFunc(F.adaptive_max_pool3d, _pool_flops_compute, + old_functions, module_flop_count, module_mac_count) + + # upsample + F.upsample = wrapFunc(F.upsample, _upsample_flops_compute, old_functions, module_flop_count, module_mac_count) + F.interpolate = wrapFunc(F.interpolate, _upsample_flops_compute, old_functions, module_flop_count, module_mac_count) + + # softmax + F.softmax = wrapFunc(F.softmax, _softmax_flops_compute, old_functions, module_flop_count, module_mac_count) + + # embedding + F.embedding = wrapFunc(F.embedding, _embedding_flops_compute, old_functions, module_flop_count, module_mac_count) + + +def _patch_tensor_methods(old_functions, module_flop_count, module_mac_count): + torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute, old_functions, module_flop_count, module_mac_count) + torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute, + old_functions, module_flop_count, module_mac_count) + # torch.mm = wrapFunc(torch.mm, _matmul_flops_compute, old_functions, module_flop_count, module_mac_count) + # torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute, old_functions, module_flop_count, module_mac_count) + # torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute, old_functions, module_flop_count, module_mac_count) + # torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute, old_functions, module_flop_count, module_mac_count) + + torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute, old_functions, module_flop_count, module_mac_count) + torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute, + old_functions, module_flop_count, module_mac_count) + + torch.mul = wrapFunc(torch.mul, _mul_flops_compute, old_functions, module_flop_count, module_mac_count) + torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute, + old_functions, module_flop_count, module_mac_count) + + torch.add = wrapFunc(torch.add, _add_flops_compute, old_functions, module_flop_count, module_mac_count) + torch.Tensor.add = wrapFunc(torch.Tensor.add, _add_flops_compute, + old_functions, module_flop_count, module_mac_count) + + torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute, old_functions, module_flop_count, module_mac_count) + + torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute, + old_functions, module_flop_count, module_mac_count) + + +def _reload_functionals(old_functions): + # torch.nn.functional does not support importlib.reload() + F.linear = old_functions[F.linear.__str__] + F.conv1d = old_functions[F.conv1d.__str__] + F.conv2d = old_functions[F.conv2d.__str__] + F.conv3d = old_functions[F.conv3d.__str__] + F.conv_transpose1d = old_functions[F.conv_transpose1d.__str__] + F.conv_transpose2d = old_functions[F.conv_transpose2d.__str__] + F.conv_transpose3d = old_functions[F.conv_transpose3d.__str__] + F.relu = old_functions[F.relu.__str__] + F.prelu = old_functions[F.prelu.__str__] + F.elu = old_functions[F.elu.__str__] + F.leaky_relu = old_functions[F.leaky_relu.__str__] + F.relu6 = old_functions[F.relu6.__str__] + if hasattr(F, "silu"): + F.silu = old_functions[F.silu.__str__] + F.gelu = old_functions[F.gelu.__str__] + F.batch_norm = old_functions[F.batch_norm.__str__] + F.layer_norm = old_functions[F.layer_norm.__str__] + F.instance_norm = old_functions[F.instance_norm.__str__] + F.group_norm = old_functions[F.group_norm.__str__] + F.avg_pool1d = old_functions[F.avg_pool1d.__str__] + F.avg_pool2d = old_functions[F.avg_pool2d.__str__] + F.avg_pool3d = old_functions[F.avg_pool3d.__str__] + F.max_pool1d = old_functions[F.max_pool1d.__str__] + F.max_pool2d = old_functions[F.max_pool2d.__str__] + F.max_pool3d = old_functions[F.max_pool3d.__str__] + F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__str__] + F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__str__] + F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__str__] + F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__str__] + F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__str__] + F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__str__] + F.upsample = old_functions[F.upsample.__str__] + F.interpolate = old_functions[F.interpolate.__str__] + F.softmax = old_functions[F.softmax.__str__] + F.embedding = old_functions[F.embedding.__str__] + + +def _reload_tensor_methods(old_functions): + torch.matmul = old_functions[torch.matmul.__str__] + torch.Tensor.matmul = old_functions[torch.Tensor.matmul.__str__] + # torch.mm = old_functions[torch.mm.__str__] + # torch.Tensor.mm = old_functions[torch.Tensor.mm.__str__] + # torch.bmm = old_functions[torch.matmul.__str__] + # torch.Tensor.bmm = old_functions[torch.Tensor.bmm.__str__] + torch.addmm = old_functions[torch.addmm.__str__] + torch.Tensor.addmm = old_functions[torch.Tensor.addmm.__str__] + torch.mul = old_functions[torch.mul.__str__] + torch.Tensor.mul = old_functions[torch.Tensor.mul.__str__] + torch.add = old_functions[torch.add.__str__] + torch.Tensor.add = old_functions[torch.Tensor.add.__str__] + torch.einsum = old_functions[torch.einsum.__str__] + torch.baddbmm = old_functions[torch.baddbmm.__str__] diff --git a/chameleon/tools/calflops/readme.md b/chameleon/tools/calflops/readme.md new file mode 100644 index 0000000..9e74a8a --- /dev/null +++ b/chameleon/tools/calflops/readme.md @@ -0,0 +1,5 @@ +# Calflops + +**Many thanks to the author fo [Calflops](https://github.com/MrYxJ/calculate-flops.pytorch)** + +We remove the hugginface dependency and make it more user-friendly. diff --git a/chameleon/tools/calflops/utils.py b/chameleon/tools/calflops/utils.py new file mode 100644 index 0000000..9aced3f --- /dev/null +++ b/chameleon/tools/calflops/utils.py @@ -0,0 +1,238 @@ +# !usr/bin/env python +# -*- coding:utf-8 -*- + +''' + Description : + Version : 1.0 + Author : MrYXJ + Mail : yxj2017@gmail.com + Github : https://github.com/MrYxJ + Date : 2023-08-19 11:01:23 + LastEditTime : 2023-09-05 15:51:50 + Copyright (C) 2023 mryxj. All rights reserved. +''' + +import importlib + +import torch + +DEFAULT_PRECISION = 2 + + +def generate_transformer_input(model_tokenizer, input_shape, device): + """Automatically generates data in the form of transformes model input format. + + Args: + input_shape (tuple):transformers model input shape: (batch_size, seq_len). + tokenizer (transformer.model.tokenization): transformers model tokenization.tokenizer. + + Returns: + dict: data format of transformers model input, it is a dict contain 'input_ids', 'attention_mask', sometime contain 'token_type_ids'. + """ + + if input_shape is None: + input_shape = [1, 128] # defautl (batch_size=1, seq_len=128) + + max_length = input_shape[1] + model_input_ids = [] + model_attention_mask = [] + model_token_type_ids = [] + model_position_ids = [] + + inp_seq = "" + for _ in range(input_shape[0]): + inputs = model_tokenizer.encode_plus( + inp_seq, + add_special_tokens=True, + truncation_strategy='longest_first', + ) + origin_length = len(inputs["input_ids"]) + padding_length = max_length - origin_length + + for key in inputs.keys(): + if key == "input_ids": + input_ids = inputs["input_ids"] + pad_token = model_tokenizer.pad_token_id if model_tokenizer.pad_token_id else 0 + input_ids = input_ids + ([pad_token] * padding_length) + assert len(input_ids) == max_length, "len(input_ids) must equal max_length" + model_input_ids.append(input_ids) + elif key == "attention_mask": + attention_mask = [1] * origin_length + attention_mask = attention_mask + ([0] * padding_length) + assert len(attention_mask) == max_length, "len(attention_mask) must equal max_length" + model_attention_mask.append(attention_mask) + elif key == "token_type_ids": + token_type_ids = inputs['token_type_ids'] + pad_token_segment_id = 0 + token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) + assert len(token_type_ids) == max_length, "len(token_type_ids) must equal max_length" + model_token_type_ids.append(token_type_ids) + elif key == "position_ids": # chatglm2 use position id + position_ids = inputs['position_ids'] + for i in range(origin_length, max_length): + position_ids.append(i) + assert len(position_ids) == max_length, "len(position_ids) must equal max_length" + model_position_ids.append(position_ids) + + # Batch size input_shape[0], sequence length input_shape[128] + inputs = {} + if len(model_input_ids) > 0: + inputs.update({"input_ids": torch.tensor(model_input_ids).to(device)}) + if len(model_attention_mask) > 0: + inputs.update({"attention_mask": torch.tensor(model_attention_mask).to(device)}) + if len(model_token_type_ids) > 0: + inputs.update({'token_type_ids': torch.tensor(model_token_type_ids).to(device)}) + if len(model_position_ids) > 0: + inputs.update({'position_ids': torch.tensor(model_position_ids).to(device)}) + + return inputs + + +def number_to_string(num, units=None, precision=DEFAULT_PRECISION): + if units is None: + if num >= 1e12: + magnitude, units = 1e12, "T" + elif num >= 1e9: + magnitude, units = 1e9, "G" + elif num >= 1e6: + magnitude, units = 1e6, "M" + elif num >= 1e3: + magnitude, units = 1e3, "K" + elif num >= 1 or num == 0: + magnitude, units = 1, "" + elif num >= 1e-3: + magnitude, units = 1e-3, "m" + else: + magnitude, units = 1e-6, "u" + else: + if units == "T": + magnitude = 1e12 + elif units == "G": + magnitude = 1e9 + elif units == "M": + magnitude = 1e6 + elif units == "K": + magnitude = 1e3 + elif units == "m": + magnitude = 1e-3 + elif units == "u": + magnitude = 1e-6 + else: + magnitude = 1 + return f"{round(num / magnitude, precision):g} {units}" + + +def macs_to_string(macs, units=None, precision=DEFAULT_PRECISION): + """Converts macs in numeric form to string form. + + Args: + macs (int): Calculate the results of the model macs in numerical form. + units (str, optional): The unit of macs after conversion to string representation, such as TMACs、GMACs、MMACs、KMACs + precision (int, optional): The number of digits of the result is preserved. Defaults to DEFAULT_PRECISION. + + Returns: + string: The string representation of macs. + """ + return f"{number_to_string(macs, units=units, precision=precision)}MACs" + + +def flops_to_string(flops, units=None, precision=DEFAULT_PRECISION): + """Converts flops in numeric form to string form. + + Args: + flops (int): Calculate the results of the model flops in numerical form. + units (str, optional): The unit of flops after conversion to string representation, such as TFLOPs,GFLOPs,MFLOPs,KFLOPs. + precision (int, optional): The number of digits of the result is preserved. Defaults to DEFAULT_PRECISION. + + Returns: + string: The string representation of flops. + """ + return f"{number_to_string(flops, units=units, precision=precision)}FLOPS" + + +def bytes_to_string(b, units=None, precision=DEFAULT_PRECISION): + """Converts bytes in numeric form to string form. + + Args: + b (int): Calculate the results of the bytes in numerical form. + units (str, optional): The unit of bytes after conversion to string representation, such as TB,GB,MB,KB. + precision (int, optional): The number of digits of the result is preserved. Defaults to DEFAULT_PRECISION. + + Returns: + string: The string representation of bytes. + """ + return f"{number_to_string(b, units=units, precision=precision)}B" + + +def params_to_string(params_num, units=None, precision=DEFAULT_PRECISION): + """Converts params in numeric form to string form. + + Args: + params_num (int): Calculate the results of the model param in numerical form. + units (str, optional): The unit of params after conversion to string representation. + precision (int, optional): The number of digits of the result is preserved. Defaults to DEFAULT_PRECISION. + + Returns: + string: The string representation of params. + """ + units = units.replace("B", "G") if units else units + return number_to_string(params_num, units=units, precision=precision).replace("G", "B").strip() + + +def get_module_flops(module, is_sparse=False): + """Recursively compute the FLOP s of the model + + Args: + module (pytorch module): model format must be pytorch + is_sparse (bool, Optional): Whether to exclude sparse weight. Defaults to False. + + Returns: + int: The sum of the entire model flops + """ + sum_flops = module.__flops__ * sum( + p.count_nonzero().item() for p in module.parameters() if p.requires_grad + ) / sum(p.numel() for p in module.parameters() if p.requires_grad) if is_sparse else module.__flops__ + # iterate over immediate children modules + for child in module.children(): + sum_flops += get_module_flops(child, is_sparse=is_sparse) + return sum_flops + + +def get_module_macs(module, is_sparse=False): + """Recursively compute the macs s of the model + + Args: + module (pytorch module): model format must be pytorch + is_sparse (bool, Optional): Whether to exclude sparse weight. Defaults to False. + + Returns: + int: The sum of the entire model macs + """ + sum_macs = module.__macs__ * sum( + p.count_nonzero().item() for p in module.parameters() if p.requires_grad + ) / sum(p.numel() for p in module.parameters() if p.requires_grad) if is_sparse else module.__macs__ + # iterate over immediate children modules + for child in module.children(): + sum_macs += get_module_macs(child, is_sparse=is_sparse) + return sum_macs + + +def convert_bytes(size): + "Converts `size` from bytes to the largest possible unit" + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if size < 1024.0: + return f"{round(size, 2)} {x}" + size /= 1024.0 + + return f"{round(size, 2)} PB" + + +def _is_package_available(pkg_name): + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + if package_exists: + try: + _ = importlib.metadata.metadata(pkg_name) + return True + except importlib.metadata.PackageNotFoundError: + return False diff --git a/chameleon/tools/cpuinfo.py b/chameleon/tools/cpuinfo.py deleted file mode 100644 index b6c167d..0000000 --- a/chameleon/tools/cpuinfo.py +++ /dev/null @@ -1,874 +0,0 @@ -################################################################### -# cpuinfo - Get information about CPU -# -# License: BSD -# Author: Pearu Peterson -# -# See LICENSES/cpuinfo.txt for details about copyright and -# rights to use. -#################################################################### -""" -cpuinfo -Copyright 2002 Pearu Peterson all rights reserved, -Pearu Peterson -Permission to use, modify, and distribute this software is given under the -terms of the NumPy (BSD style) license. See LICENSE.txt that came with -this distribution for specifics. -NO WARRANTY IS EXPRESSED OR IMPLIED. USE AT YOUR OWN RISK. -Pearu Peterson - -Ref: https://github.com/pydata/numexpr/blob/master/numexpr/cpuinfo.py - -Usage: - >>> from cpuinfo import cpuinfo - >>> info = cpuinfo() # len(info) equals to num of cpus. - >>> print(list(info[0].keys())) - >>> { - 'processor', - 'vendor_id', - 'cpu family', - 'model', - 'model name', - 'stepping', - 'microcode', - 'cpu MHz', - 'cache size', - 'physical id', - 'siblings', - 'core id', - 'cpu cores', - 'apicid', - 'initial apicid', - 'fpu', - 'fpu_exception', - 'cpuid level', - 'wp', - 'flags', - 'vmx flags', - 'bugs', - 'bogomips', - 'clflush size', - 'cache_alignment', - 'address sizes', - 'power management' - } -""" - -__all__ = ['cpuinfo'] - -import inspect -import os -import platform -import re -import subprocess -import sys -import warnings - -is_cpu_amd_intel = False # DEPRECATION WARNING: WILL BE REMOVED IN FUTURE RELEASE - -def getoutput(cmd, successful_status=(0,), stacklevel=1): - try: - p = subprocess.Popen(cmd, stdout=subprocess.PIPE) - output, _ = p.communicate() - status = p.returncode - except EnvironmentError as e: - warnings.warn(str(e), UserWarning, stacklevel=stacklevel) - return False, '' - if os.WIFEXITED(status) and os.WEXITSTATUS(status) in successful_status: - return True, output - return False, output - - -def command_info(successful_status=(0,), stacklevel=1, **kw): - info = {} - for key in kw: - ok, output = getoutput(kw[key], successful_status=successful_status, - stacklevel=stacklevel + 1) - if ok: - info[key] = output.strip() - return info - - -def command_by_line(cmd, successful_status=(0,), stacklevel=1): - ok, output = getoutput(cmd, successful_status=successful_status, - stacklevel=stacklevel + 1) - if not ok: - return - - # XXX: check - output = output.decode('ascii') - - for line in output.splitlines(): - yield line.strip() - - -def key_value_from_command(cmd, sep, successful_status=(0,), - stacklevel=1): - d = {} - for line in command_by_line(cmd, successful_status=successful_status, - stacklevel=stacklevel + 1): - l = [s.strip() for s in line.split(sep, 1)] - if len(l) == 2: - d[l[0]] = l[1] - return d - - -class CPUInfoBase(object): - """Holds CPU information and provides methods for requiring - the availability of various CPU features. - """ - - def _try_call(self, func): - try: - return func() - except: - pass - - def __getattr__(self, name): - if not name.startswith('_'): - if hasattr(self, '_' + name): - attr = getattr(self, '_' + name) - if inspect.ismethod(attr): - return lambda func=self._try_call, attr=attr: func(attr) - else: - return lambda: None - raise AttributeError(name) - - def _getNCPUs(self): - return 1 - - def __get_nbits(self): - abits = platform.architecture()[0] - nbits = re.compile(r'(\d+)bit').search(abits).group(1) - return nbits - - def _is_32bit(self): - return self.__get_nbits() == '32' - - def _is_64bit(self): - return self.__get_nbits() == '64' - - -class LinuxCPUInfo(CPUInfoBase): - info = None - - def __init__(self): - if self.info is not None: - return - info = [{}] - ok, output = getoutput(['uname', '-m']) - if ok: - info[0]['uname_m'] = output.strip() - try: - fo = open('/proc/cpuinfo') - except EnvironmentError as e: - warnings.warn(str(e), UserWarning) - else: - for line in fo: - name_value = [s.strip() for s in line.split(':', 1)] - if len(name_value) != 2: - continue - name, value = name_value - if not info or name in info[-1]: # next processor - info.append({}) - info[-1][name] = value - fo.close() - self.__class__.info = info - - def _not_impl(self): - pass - - # Athlon - - def _is_AMD(self): - return self.info[0]['vendor_id'] == 'AuthenticAMD' - - def _is_AthlonK6_2(self): - return self._is_AMD() and self.info[0]['model'] == '2' - - def _is_AthlonK6_3(self): - return self._is_AMD() and self.info[0]['model'] == '3' - - def _is_AthlonK6(self): - return re.match(r'.*?AMD-K6', self.info[0]['model name']) is not None - - def _is_AthlonK7(self): - return re.match(r'.*?AMD-K7', self.info[0]['model name']) is not None - - def _is_AthlonMP(self): - return re.match(r'.*?Athlon\(tm\) MP\b', - self.info[0]['model name']) is not None - - def _is_AMD64(self): - return self.is_AMD() and self.info[0]['family'] == '15' - - def _is_Athlon64(self): - return re.match(r'.*?Athlon\(tm\) 64\b', - self.info[0]['model name']) is not None - - def _is_AthlonHX(self): - return re.match(r'.*?Athlon HX\b', - self.info[0]['model name']) is not None - - def _is_Opteron(self): - return re.match(r'.*?Opteron\b', - self.info[0]['model name']) is not None - - def _is_Hammer(self): - return re.match(r'.*?Hammer\b', - self.info[0]['model name']) is not None - - # Alpha - - def _is_Alpha(self): - return self.info[0]['cpu'] == 'Alpha' - - def _is_EV4(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'EV4' - - def _is_EV5(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'EV5' - - def _is_EV56(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'EV56' - - def _is_PCA56(self): - return self.is_Alpha() and self.info[0]['cpu model'] == 'PCA56' - - # Intel - - #XXX - _is_i386 = _not_impl - - def _is_Intel(self): - return self.info[0]['vendor_id'] == 'GenuineIntel' - - def _is_i486(self): - return self.info[0]['cpu'] == 'i486' - - def _is_i586(self): - return self.is_Intel() and self.info[0]['cpu family'] == '5' - - def _is_i686(self): - return self.is_Intel() and self.info[0]['cpu family'] == '6' - - def _is_Celeron(self): - return re.match(r'.*?Celeron', - self.info[0]['model name']) is not None - - def _is_Pentium(self): - return re.match(r'.*?Pentium', - self.info[0]['model name']) is not None - - def _is_PentiumII(self): - return re.match(r'.*?Pentium.*?II\b', - self.info[0]['model name']) is not None - - def _is_PentiumPro(self): - return re.match(r'.*?PentiumPro\b', - self.info[0]['model name']) is not None - - def _is_PentiumMMX(self): - return re.match(r'.*?Pentium.*?MMX\b', - self.info[0]['model name']) is not None - - def _is_PentiumIII(self): - return re.match(r'.*?Pentium.*?III\b', - self.info[0]['model name']) is not None - - def _is_PentiumIV(self): - return re.match(r'.*?Pentium.*?(IV|4)\b', - self.info[0]['model name']) is not None - - def _is_PentiumM(self): - return re.match(r'.*?Pentium.*?M\b', - self.info[0]['model name']) is not None - - def _is_Prescott(self): - return self.is_PentiumIV() and self.has_sse3() - - def _is_Nocona(self): - return (self.is_Intel() and - self.info[0]['cpu family'] in ('6', '15') and - # two s sse3; three s ssse3 not the same thing, this is fine - (self.has_sse3() and not self.has_ssse3()) and - re.match(r'.*?\blm\b', self.info[0]['flags']) is not None) - - def _is_Core2(self): - return (self.is_64bit() and self.is_Intel() and - re.match(r'.*?Core\(TM\)2\b', - self.info[0]['model name']) is not None) - - def _is_Itanium(self): - return re.match(r'.*?Itanium\b', - self.info[0]['family']) is not None - - def _is_XEON(self): - return re.match(r'.*?XEON\b', - self.info[0]['model name'], re.IGNORECASE) is not None - - _is_Xeon = _is_XEON - - # Power - def _is_Power(self): - return re.match(r'.*POWER.*', - self.info[0]['cpu']) is not None - - def _is_Power7(self): - return re.match(r'.*POWER7.*', - self.info[0]['cpu']) is not None - - def _is_Power8(self): - return re.match(r'.*POWER8.*', - self.info[0]['cpu']) is not None - - def _is_Power9(self): - return re.match(r'.*POWER9.*', - self.info[0]['cpu']) is not None - - def _has_Altivec(self): - return re.match(r'.*altivec\ supported.*', - self.info[0]['cpu']) is not None - - # Varia - - def _is_singleCPU(self): - return len(self.info) == 1 - - def _getNCPUs(self): - return len(self.info) - - def _has_fdiv_bug(self): - return self.info[0]['fdiv_bug'] == 'yes' - - def _has_f00f_bug(self): - return self.info[0]['f00f_bug'] == 'yes' - - def _has_mmx(self): - return re.match(r'.*?\bmmx\b', self.info[0]['flags']) is not None - - def _has_sse(self): - return re.match(r'.*?\bsse\b', self.info[0]['flags']) is not None - - def _has_sse2(self): - return re.match(r'.*?\bsse2\b', self.info[0]['flags']) is not None - - def _has_sse3(self): - return re.match(r'.*?\bpni\b', self.info[0]['flags']) is not None - - def _has_ssse3(self): - return re.match(r'.*?\bssse3\b', self.info[0]['flags']) is not None - - def _has_3dnow(self): - return re.match(r'.*?\b3dnow\b', self.info[0]['flags']) is not None - - def _has_3dnowext(self): - return re.match(r'.*?\b3dnowext\b', self.info[0]['flags']) is not None - - -class IRIXCPUInfo(CPUInfoBase): - info = None - - def __init__(self): - if self.info is not None: - return - info = key_value_from_command('sysconf', sep=' ', - successful_status=(0, 1)) - self.__class__.info = info - - def _not_impl(self): - pass - - def _is_singleCPU(self): - return self.info.get('NUM_PROCESSORS') == '1' - - def _getNCPUs(self): - return int(self.info.get('NUM_PROCESSORS', 1)) - - def __cputype(self, n): - return self.info.get('PROCESSORS').split()[0].lower() == 'r%s' % (n) - - def _is_r2000(self): - return self.__cputype(2000) - - def _is_r3000(self): - return self.__cputype(3000) - - def _is_r3900(self): - return self.__cputype(3900) - - def _is_r4000(self): - return self.__cputype(4000) - - def _is_r4100(self): - return self.__cputype(4100) - - def _is_r4300(self): - return self.__cputype(4300) - - def _is_r4400(self): - return self.__cputype(4400) - - def _is_r4600(self): - return self.__cputype(4600) - - def _is_r4650(self): - return self.__cputype(4650) - - def _is_r5000(self): - return self.__cputype(5000) - - def _is_r6000(self): - return self.__cputype(6000) - - def _is_r8000(self): - return self.__cputype(8000) - - def _is_r10000(self): - return self.__cputype(10000) - - def _is_r12000(self): - return self.__cputype(12000) - - def _is_rorion(self): - return self.__cputype('orion') - - def get_ip(self): - try: - return self.info.get('MACHINE') - except: - pass - - def __machine(self, n): - return self.info.get('MACHINE').lower() == 'ip%s' % (n) - - def _is_IP19(self): - return self.__machine(19) - - def _is_IP20(self): - return self.__machine(20) - - def _is_IP21(self): - return self.__machine(21) - - def _is_IP22(self): - return self.__machine(22) - - def _is_IP22_4k(self): - return self.__machine(22) and self._is_r4000() - - def _is_IP22_5k(self): - return self.__machine(22) and self._is_r5000() - - def _is_IP24(self): - return self.__machine(24) - - def _is_IP25(self): - return self.__machine(25) - - def _is_IP26(self): - return self.__machine(26) - - def _is_IP27(self): - return self.__machine(27) - - def _is_IP28(self): - return self.__machine(28) - - def _is_IP30(self): - return self.__machine(30) - - def _is_IP32(self): - return self.__machine(32) - - def _is_IP32_5k(self): - return self.__machine(32) and self._is_r5000() - - def _is_IP32_10k(self): - return self.__machine(32) and self._is_r10000() - - -class DarwinCPUInfo(CPUInfoBase): - info = None - - def __init__(self): - if self.info is not None: - return - info = command_info(arch='arch', - machine='machine') - info['sysctl_hw'] = key_value_from_command(['sysctl', 'hw'], sep='=') - self.__class__.info = info - - def _not_impl(self): pass - - def _getNCPUs(self): - return int(self.info['sysctl_hw'].get('hw.ncpu', 1)) - - def _is_Power_Macintosh(self): - return self.info['sysctl_hw']['hw.machine'] == 'Power Macintosh' - - def _is_i386(self): - return self.info['arch'] == 'i386' - - def _is_ppc(self): - return self.info['arch'] == 'ppc' - - def __machine(self, n): - return self.info['machine'] == 'ppc%s' % n - - def _is_ppc601(self): return self.__machine(601) - - def _is_ppc602(self): return self.__machine(602) - - def _is_ppc603(self): return self.__machine(603) - - def _is_ppc603e(self): return self.__machine('603e') - - def _is_ppc604(self): return self.__machine(604) - - def _is_ppc604e(self): return self.__machine('604e') - - def _is_ppc620(self): return self.__machine(620) - - def _is_ppc630(self): return self.__machine(630) - - def _is_ppc740(self): return self.__machine(740) - - def _is_ppc7400(self): return self.__machine(7400) - - def _is_ppc7450(self): return self.__machine(7450) - - def _is_ppc750(self): return self.__machine(750) - - def _is_ppc403(self): return self.__machine(403) - - def _is_ppc505(self): return self.__machine(505) - - def _is_ppc801(self): return self.__machine(801) - - def _is_ppc821(self): return self.__machine(821) - - def _is_ppc823(self): return self.__machine(823) - - def _is_ppc860(self): return self.__machine(860) - -class NetBSDCPUInfo(CPUInfoBase): - info = None - - def __init__(self): - if self.info is not None: - return - info = {} - info['sysctl_hw'] = key_value_from_command(['sysctl', 'hw'], sep='=') - info['arch'] = info['sysctl_hw'].get('hw.machine_arch', 1) - info['machine'] = info['sysctl_hw'].get('hw.machine', 1) - self.__class__.info = info - - def _not_impl(self): pass - - def _getNCPUs(self): - return int(self.info['sysctl_hw'].get('hw.ncpu', 1)) - - def _is_Intel(self): - if self.info['sysctl_hw'].get('hw.model', "")[0:5] == 'Intel': - return True - return False - - def _is_AMD(self): - if self.info['sysctl_hw'].get('hw.model', "")[0:3] == 'AMD': - return True - return False - -class SunOSCPUInfo(CPUInfoBase): - info = None - - def __init__(self): - if self.info is not None: - return - info = command_info(arch='arch', - mach='mach', - uname_i=['uname', '-i'], - isainfo_b=['isainfo', '-b'], - isainfo_n=['isainfo', '-n'], - ) - info['uname_X'] = key_value_from_command(['uname', '-X'], sep='=') - for line in command_by_line(['psrinfo', '-v', '0']): - m = re.match(r'\s*The (?P

[\w\d]+) processor operates at', line) - if m: - info['processor'] = m.group('p') - break - self.__class__.info = info - - def _not_impl(self): - pass - - def _is_i386(self): - return self.info['isainfo_n'] == 'i386' - - def _is_sparc(self): - return self.info['isainfo_n'] == 'sparc' - - def _is_sparcv9(self): - return self.info['isainfo_n'] == 'sparcv9' - - def _getNCPUs(self): - return int(self.info['uname_X'].get('NumCPU', 1)) - - def _is_sun4(self): - return self.info['arch'] == 'sun4' - - def _is_SUNW(self): - return re.match(r'SUNW', self.info['uname_i']) is not None - - def _is_sparcstation5(self): - return re.match(r'.*SPARCstation-5', self.info['uname_i']) is not None - - def _is_ultra1(self): - return re.match(r'.*Ultra-1', self.info['uname_i']) is not None - - def _is_ultra250(self): - return re.match(r'.*Ultra-250', self.info['uname_i']) is not None - - def _is_ultra2(self): - return re.match(r'.*Ultra-2', self.info['uname_i']) is not None - - def _is_ultra30(self): - return re.match(r'.*Ultra-30', self.info['uname_i']) is not None - - def _is_ultra4(self): - return re.match(r'.*Ultra-4', self.info['uname_i']) is not None - - def _is_ultra5_10(self): - return re.match(r'.*Ultra-5_10', self.info['uname_i']) is not None - - def _is_ultra5(self): - return re.match(r'.*Ultra-5', self.info['uname_i']) is not None - - def _is_ultra60(self): - return re.match(r'.*Ultra-60', self.info['uname_i']) is not None - - def _is_ultra80(self): - return re.match(r'.*Ultra-80', self.info['uname_i']) is not None - - def _is_ultraenterprice(self): - return re.match(r'.*Ultra-Enterprise', self.info['uname_i']) is not None - - def _is_ultraenterprice10k(self): - return re.match(r'.*Ultra-Enterprise-10000', self.info['uname_i']) is not None - - def _is_sunfire(self): - return re.match(r'.*Sun-Fire', self.info['uname_i']) is not None - - def _is_ultra(self): - return re.match(r'.*Ultra', self.info['uname_i']) is not None - - def _is_cpusparcv7(self): - return self.info['processor'] == 'sparcv7' - - def _is_cpusparcv8(self): - return self.info['processor'] == 'sparcv8' - - def _is_cpusparcv9(self): - return self.info['processor'] == 'sparcv9' - - -class Win32CPUInfo(CPUInfoBase): - info = None - pkey = r"HARDWARE\DESCRIPTION\System\CentralProcessor" - # XXX: what does the value of - # HKEY_LOCAL_MACHINE\HARDWARE\DESCRIPTION\System\CentralProcessor\0 - # mean? - - def __init__(self): - try: - import _winreg - except ImportError: # Python 3 - import winreg as _winreg - - if self.info is not None: - return - info = [] - try: - #XXX: Bad style to use so long `try:...except:...`. Fix it! - - prgx = re.compile(r"family\s+(?P\d+)\s+model\s+(?P\d+)" - r"\s+stepping\s+(?P\d+)", re.IGNORECASE) - chnd = _winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE, self.pkey) - pnum = 0 - while 1: - try: - proc = _winreg.EnumKey(chnd, pnum) - except _winreg.error: - break - else: - pnum += 1 - info.append({"Processor": proc}) - phnd = _winreg.OpenKey(chnd, proc) - pidx = 0 - while True: - try: - name, value, _ = _winreg.EnumValue(phnd, pidx) - except _winreg.error: - break - else: - pidx = pidx + 1 - info[-1][name] = value - if name == "Identifier": - srch = prgx.search(value) - if srch: - info[-1]["Family"] = int(srch.group("FML")) - info[-1]["Model"] = int(srch.group("MDL")) - info[-1]["Stepping"] = int(srch.group("STP")) - except: - print(sys.exc_value, '(ignoring)') - self.__class__.info = info - - def _not_impl(self): - pass - - # Athlon - - def _is_AMD(self): - return self.info[0]['VendorIdentifier'] == 'AuthenticAMD' - - def _is_Am486(self): - return self.is_AMD() and self.info[0]['Family'] == 4 - - def _is_Am5x86(self): - return self.is_AMD() and self.info[0]['Family'] == 4 - - def _is_AMDK5(self): - return (self.is_AMD() and self.info[0]['Family'] == 5 and - self.info[0]['Model'] in [0, 1, 2, 3]) - - def _is_AMDK6(self): - return (self.is_AMD() and self.info[0]['Family'] == 5 and - self.info[0]['Model'] in [6, 7]) - - def _is_AMDK6_2(self): - return (self.is_AMD() and self.info[0]['Family'] == 5 and - self.info[0]['Model'] == 8) - - def _is_AMDK6_3(self): - return (self.is_AMD() and self.info[0]['Family'] == 5 and - self.info[0]['Model'] == 9) - - def _is_AMDK7(self): - return self.is_AMD() and self.info[0]['Family'] == 6 - - # To reliably distinguish between the different types of AMD64 chips - # (Athlon64, Operton, Athlon64 X2, Semperon, Turion 64, etc.) would - # require looking at the 'brand' from cpuid - - def _is_AMD64(self): - return self.is_AMD() and self.info[0]['Family'] == 15 - - # Intel - - def _is_Intel(self): - return self.info[0]['VendorIdentifier'] == 'GenuineIntel' - - def _is_i386(self): - return self.info[0]['Family'] == 3 - - def _is_i486(self): - return self.info[0]['Family'] == 4 - - def _is_i586(self): - return self.is_Intel() and self.info[0]['Family'] == 5 - - def _is_i686(self): - return self.is_Intel() and self.info[0]['Family'] == 6 - - def _is_Pentium(self): - return self.is_Intel() and self.info[0]['Family'] == 5 - - def _is_PentiumMMX(self): - return (self.is_Intel() and self.info[0]['Family'] == 5 and - self.info[0]['Model'] == 4) - - def _is_PentiumPro(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] == 1) - - def _is_PentiumII(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [3, 5, 6]) - - def _is_PentiumIII(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [7, 8, 9, 10, 11]) - - def _is_PentiumIV(self): - return self.is_Intel() and self.info[0]['Family'] == 15 - - def _is_PentiumM(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [9, 13, 14]) - - def _is_Core2(self): - return (self.is_Intel() and self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [15, 16, 17]) - - # Varia - - def _is_singleCPU(self): - return len(self.info) == 1 - - def _getNCPUs(self): - return len(self.info) - - def _has_mmx(self): - if self.is_Intel(): - return ((self.info[0]['Family'] == 5 and - self.info[0]['Model'] == 4) or - (self.info[0]['Family'] in [6, 15])) - elif self.is_AMD(): - return self.info[0]['Family'] in [5, 6, 15] - else: - return False - - def _has_sse(self): - if self.is_Intel(): - return ((self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [7, 8, 9, 10, 11]) or - self.info[0]['Family'] == 15) - elif self.is_AMD(): - return ((self.info[0]['Family'] == 6 and - self.info[0]['Model'] in [6, 7, 8, 10]) or - self.info[0]['Family'] == 15) - else: - return False - - def _has_sse2(self): - if self.is_Intel(): - return self.is_Pentium4() or self.is_PentiumM() or self.is_Core2() - elif self.is_AMD(): - return self.is_AMD64() - else: - return False - - def _has_3dnow(self): - return self.is_AMD() and self.info[0]['Family'] in [5, 6, 15] - - def _has_3dnowext(self): - return self.is_AMD() and self.info[0]['Family'] in [6, 15] - - -if sys.platform.startswith('linux'): # variations: linux2,linux-i386 (any others?) - cpuinfo = LinuxCPUInfo -elif sys.platform.startswith('irix'): - cpuinfo = IRIXCPUInfo -elif sys.platform == 'darwin': - cpuinfo = DarwinCPUInfo -elif sys.platform[0:6] == 'netbsd': - cpuinfo = NetBSDCPUInfo -elif sys.platform.startswith('sunos'): - cpuinfo = SunOSCPUInfo -elif sys.platform.startswith('win32'): - cpuinfo = Win32CPUInfo -elif sys.platform.startswith('cygwin'): - cpuinfo = LinuxCPUInfo -#XXX: other OS's. Eg. use _winreg on Win32. Or os.uname on unices. -else: - cpuinfo = CPUInfoBase diff --git a/chameleon/tools/custom_aug.py b/chameleon/tools/custom_aug.py deleted file mode 100644 index b16cbd5..0000000 --- a/chameleon/tools/custom_aug.py +++ /dev/null @@ -1,99 +0,0 @@ -import math -import random -from typing import Tuple - -import albumentations as A -import cv2 -import numpy as np -from PIL import Image - -from .mixin import BorderValueMixin, FillValueMixin - -__all__ = [ - 'RandomSunFlare', 'CoarseDropout', 'ShiftScaleRotate', 'SaftRotate', - 'Perspective', 'Shear', 'Rotate180', -] - - -class RandomSunFlare(A.RandomSunFlare): - - @property - def src_radius(self): - return random.randint(50, 200) - - @src_radius.setter - def src_radius(self, x): - return None - - -class CoarseDropout(FillValueMixin, A.CoarseDropout): - ... - - -class ShiftScaleRotate(BorderValueMixin, A.ShiftScaleRotate): - ... - - -class SaftRotate(BorderValueMixin, A.SafeRotate): - ... - - -class Perspective(BorderValueMixin, A.Perspective): - ... - - -class Shear: - - def __init__(self, max_shear: Tuple[int, int] = (20, 20), p: float = 0.5): - self.p = p - self.max_shear_left, self.max_shear_right = max_shear - - def __call__(self, img): - if np.random.rand() < self.p: - height, width, *_ = img.shape - img = Image.fromarray(img) - - angle_to_shear = int( - np.random.uniform(-self.max_shear_left - 1, self.max_shear_right + 1)) - if angle_to_shear != -1: - angle_to_shear += 1 - - phi = math.tan(math.radians(angle_to_shear)) - shift_in_pixels = phi * height - shift_in_pixels = math.ceil(shift_in_pixels) \ - if shift_in_pixels > 0 else math.floor(shift_in_pixels) - - matrix_offset = shift_in_pixels - if angle_to_shear <= 0: - shift_in_pixels = abs(shift_in_pixels) - matrix_offset = 0 - phi = abs(phi) * -1 - - transform_matrix = (1, phi, -matrix_offset, 0, 1, 0) - img = img.transform((int(round(width + shift_in_pixels)), height), - Image.AFFINE, - transform_matrix, - Image.BICUBIC) - - img = img.crop((abs(shift_in_pixels), 0, width, height)) - img = cv2.resize(np.array(img), (width, height)) - - return img - - -class Rotate180: - - def __init__(self, p: float = 0.5): - self.p = p - self.rotate180 = A.Compose([ - A.HorizontalFlip(p=1), - A.VerticalFlip(p=1), - ], p=1) - - def __call__(self, **kwargs): - is_rotate = 0 - if np.random.rand() < self.p: - results = self.rotate180(**kwargs) - is_rotate = 1 - results.update({'is_rotate': is_rotate}) - return results diff --git a/chameleon/tools/mixin.py b/chameleon/tools/mixin.py deleted file mode 100644 index de4c274..0000000 --- a/chameleon/tools/mixin.py +++ /dev/null @@ -1,51 +0,0 @@ -import random - -import cv2 - -__all__ = [ - 'BorderValueMixin', 'FillValueMixin', -] - - -class BorderValueMixin: - - @property - def pad_mode(self): - return random.choice([ - cv2.BORDER_CONSTANT, - cv2.BORDER_REPLICATE, - ]) - - @property - def border_mode(self): - return random.choice([ - cv2.BORDER_CONSTANT, - cv2.BORDER_REPLICATE, - ]) - - @property - def value(self): - return [random.randint(0, 255) for _ in range(3)] - - @pad_mode.setter - def pad_mode(self, x): - return None - - @border_mode.setter - def border_mode(self, x): - return None - - @value.setter - def value(self, x): - return None - - -class FillValueMixin: - - @property - def fill_value(self): - return [random.randint(0, 255) for _ in range(3)] - - @fill_value.setter - def fill_value(self, x): - return None diff --git a/chameleon/tools/model_profile.py b/chameleon/tools/model_profile.py deleted file mode 100644 index 2892b9d..0000000 --- a/chameleon/tools/model_profile.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Dict, Union - -import torch -from calflops import calculate_flops -from ptflops import get_model_complexity_info - -from .cpuinfo import cpuinfo - -__all__ = ['get_model_complexity_info', - 'get_cpu_gflops', 'get_meta_info', 'calculate_flops'] - - -def get_cpu_gflops(one_cpu_core: bool = True) -> float: - _cpuinfo = cpuinfo() - ghz = float(_cpuinfo.info[0]['cpu MHz']) * 10e-3 - core = 1 if one_cpu_core else int(_cpuinfo.info[0]['cpu cores']) - gflops = ghz * core * 10e9 - return gflops - - -def get_meta_info(macs: float, params: int, one_cpu_core: bool = True) -> dict: - return { - 'Params(M)': f"{params/1e6:.3f}", - 'MACs(G)': f"{macs/1e9:.3f}", - 'FLOPs(G)': f"{(macs * 2)/1e9:.3f}", - 'ModelSize_FP32 (MB)': f"{params * 4 / 1e6:.3f}", - 'CPU infos': { - 'cpu_model_name': cpuinfo().info[0]['model name'], - 'cpu_cores': cpuinfo().info[0]['cpu cores'], - 'infer_time (ms) (*rough estimate*)': f"{(macs * 2) * 1000 / get_cpu_gflops(one_cpu_core):.3f}", - } - } - - -def profile_model( - model: Union[torch.nn.Module, str], - input_shape: tuple = (1, 3, 224, 224), - output_as_string: bool = False, - output_precision: int = 4, - print_detailed: bool = False, - features_only: bool = True, - one_cpu_core: bool = True -) -> Dict[str, str]: - """ - Profile a model to get its meta data. - - Args: - model (Union[torch.nn.Module, str]): Model to be profiled. If a string is given, it will be treated as the model name by timm library. - input_shape (tuple): Input shape of the model. Default: (1, 3, 224, 224). - output_as_string (bool): Whether to output the results as string. Default: False. - output_precision (int): Precision of the output. Default: 4. - print_detailed (bool): Whether to print detailed information. Default: False. - features_only (bool): Whether to calculate only the features. Default: True. - one_cpu_core (bool): Whether to use only one CPU core. Default: True. - - Returns: - Dict[str, str]: Meta data of the model. - """ - - if isinstance(model, str): - - import timm - - model = timm.create_model( - model, - pretrained=False, - features_only=features_only - ) - - _, macs, params = calculate_flops( - model, - input_shape=input_shape, - output_as_string=output_as_string, - output_precision=output_precision, - print_detailed=print_detailed - ) - - meta_data = get_meta_info(macs, params, one_cpu_core=one_cpu_core) - - return meta_data diff --git a/chameleon/tools/replace.py b/chameleon/tools/replace.py deleted file mode 100644 index 368f4ff..0000000 --- a/chameleon/tools/replace.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any, Union - -import torch.nn as nn - -from ..nn import build_nn, build_nn_cls - -__all__ = ['has_children', 'replace_module', 'replace_module_attr_value'] - - -def has_children(module): - try: - next(module.children()) - return True - except StopIteration: - return False - - -def replace_module( - model: nn.Module, - target: Union[type, str], - dst_module: Union[nn.Module, dict] -) -> None: - """ - Function to replace modules. - - Args: - model (nn.Module): - NN module. - target (Union[type, str]): - The type of module you want to replace. - dst_module (Union[nn.Module, dict]): - The module you want to use after replacement. - """ - if not isinstance(dst_module, (nn.Module, dict)): - raise ValueError(f'dst_module = {dst_module} should be an instance of Module or dict.') - - target = build_nn_cls(target) if isinstance(target, str) else target - dst_module = build_nn(**dst_module) if isinstance(dst_module, dict) else dst_module - - for name, m in model.named_children(): - if has_children(m): - replace_module(m, target, dst_module) - else: - if isinstance(m, target): - setattr(model, name, dst_module) - - -def replace_module_attr_value( - model: nn.Module, - target: Union[type, str], - attr_name: str, - attr_value: Any -) -> None: - """ - Function to replace attr's value in target module - - Args: - model (nn.Module): NN module. - target (Union[type, str]): The type of module you want to modify. - attr_name (str): The name of the attribute you want to modify. - attr_value (Any): The new value of the attribute. - """ - target = build_nn_cls(target) if isinstance(target, str) else target - for module in model.modules(): - if isinstance(module, target): - setattr(module, attr_name, attr_value) diff --git a/chameleon/transformers/__init__.py b/chameleon/transformers/__init__.py deleted file mode 100644 index cbc3b70..0000000 --- a/chameleon/transformers/__init__.py +++ /dev/null @@ -1,108 +0,0 @@ -import fnmatch -from functools import partial - -from .basic import ImageEncoder, ImageEncoderLayer -from .efficientformer import EfficientFormer -from .metaformer import MetaFormer, MlpBlock -from .mobilevit import MobileViT -from .poolformer import PoolFormer -from .token_mixer import (Attention, AttentionMixing, PoolMixing, RandomMixing, - SepConvMixing) -from .utils import calculate_patch_size, list_models_transformers -from .vit import ViT - -__all__ = [ - 'ViT', 'calculate_patch_size', 'list_models_transformers', 'MobileViT', - 'PoolFormer', 'MetaFormer', 'MlpBlock', 'build_transformer', 'list_transformer', - 'Attention', 'AttentionMixing', 'PoolMixing', 'RandomMixing', 'SepConvMixing', - 'ImageEncoder', 'ImageEncoderLayer', 'EfficientFormer', -] - -BASE_TRANSFORMER_NAMES = { - 'vit': ViT, - 'mobilevit': MobileViT, - 'poolformer': PoolFormer, - 'metaformer': MetaFormer, - 'efficientformer': EfficientFormer, -} - -VIT_NAMES = [ - 'vit-base-patch16-224-in21k', - 'vit-base-patch16-224', - 'vit-base-patch16-384', - 'vit-base-patch32-224-in21k', - 'vit-base-patch32-384', - 'vit-huge-patch14-224-in21k', - 'vit-large-patch16-224-in21k', - 'vit-large-patch16-224', - 'vit-large-patch16-384', - 'vit-large-patch32-224-in21k', - 'vit-large-patch32-384', - 'vit-hybrid-base-bit-384', -] - -MOBILEVIT_NAMES = [ - 'mobilevit-small', - 'mobilevit-x-small', - 'mobilevit-xx-small', - 'deeplabv3-mobilevit-small', - 'deeplabv3-mobilevit-x-small', - 'deeplabv3-mobilevit-xx-small', -] - -POOLFORMER_NAMES = [ - 'poolformer_m36', - 'poolformer_m48', - 'poolformer_s12', - 'poolformer_s24', - 'poolformer_s36', -] - -METAFORMER_NAMES = [ - 'poolformer_v2_tiny', - 'poolformer_v2_small', - 'poolformer_v2_s12', - 'poolformer_v2_s24', - 'poolformer_v2_s36', - 'poolformer_v2_m36', - 'poolformer_v2_m48', - 'convformer_s18', - 'convformer_s36', - 'convformer_m36', - 'convformer_b36', - 'caformer_tiny', - 'caformer_small', - 'caformer_s18', - 'caformer_s36', - 'caformer_m36', - 'caformer_b36', -] - -EFFICIENTFORMER_NAMES = [ - 'efficientformer-l1-300', - 'efficientformer-l3-300', - 'efficientformer-l7-300', -] - -TRANSFORMER = { - **{name: module for name, module in BASE_TRANSFORMER_NAMES.items()}, - **{name: partial(ViT.from_pretrained, name=f'google/{name}') for name in VIT_NAMES}, - **{name: partial(MobileViT.from_pretrained, name=f'apple/{name}') for name in MOBILEVIT_NAMES}, - **{name: partial(PoolFormer.from_pretrained, name=f'sail/{name}') for name in POOLFORMER_NAMES}, - **{name: partial(MetaFormer.from_pretrained, name=name) for name in METAFORMER_NAMES}, - **{name: partial(EfficientFormer.from_pretrained, name=f'snap-research/{name}') for name in EFFICIENTFORMER_NAMES}, -} - - -def build_transformer(name: str, **kwargs): - if name not in TRANSFORMER: - raise ValueError(f'Transformer={name} is not supported.') - return TRANSFORMER[name](**kwargs) - - -def list_transformer(filter=''): - model_list = list(TRANSFORMER.keys()) - if len(filter): - return fnmatch.filter(model_list, filter) # include these models - else: - return model_list diff --git a/chameleon/transformers/basic.py b/chameleon/transformers/basic.py deleted file mode 100644 index f0f656a..0000000 --- a/chameleon/transformers/basic.py +++ /dev/null @@ -1,127 +0,0 @@ -import math -from typing import Tuple, Union - -import torch -import torch.nn as nn - -from ..nn.components import build_activation -from .token_mixer import SelfAttention - -__all__ = ['ImageEncoder', 'ImageEncoderLayer'] - - -class ImageEncoderLayer(nn.Module): - - def __init__( - self, - d_model: int, - nhead: int = 8, - expand_ratio: float = 2, - norm_first: bool = True, - inner_act: Union[dict, nn.Module] = {'name': 'StarReLU'}, - ) -> None: - """ - Initializes the EncoderLayer. - - Args: - d_model (int): - The number of input dimensions. - nhead (int): - The number of attention heads. - expand_ratio (float, optional): - The expansion ratio for the hidden dimensions. - Defaults to 2. - norm_first (bool, optional): - Whether to apply the normalization before the attention layer. - Defaults to True. - inner_act (Union[dict, nn.Module], optional): - The activation function to use for the inner feedforward layer. - Defaults to {'name': 'StarReLU'}. - """ - super().__init__() - hidden_dims = int(d_model * expand_ratio) - self.ffn = nn.Sequential( - nn.Linear(d_model, hidden_dims), - inner_act if isinstance( - inner_act, nn.Module) else build_activation(**inner_act), - nn.Linear(hidden_dims, d_model), - ) - self.att = SelfAttention(embed_dim=d_model, num_heads=nhead) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm_first = norm_first - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.norm_first: - norm_x = self.norm1(x) - att, att_weights = self.att(norm_x, norm_x, norm_x) - x = x + att - x = x + self.ffn(self.norm2(x)) - else: - att, att_weights = self.att(x, x, x) - x = self.norm1(x + att) - x = self.norm2(x + self.ffn(x)) - return x, att_weights - - -class ImageEncoder(nn.Module): - - def __init__( - self, - d_model: int, - num_layers: int, - image_size: Union[int, Tuple[int, int]], - patch_size: Union[int, Tuple[int, int]] = 16, - in_c: int = 3, - *args, **kwargs, - ) -> None: - """ - Initialize a ImageEncoder module. - - Args: - d_model (int): - The input dimension of the encoder. - num_layers (int): - The number of layers in the encoder. - image_size (Union[int, Tuple[int, int]]): - The input image size. - patch_size (Union[int, Tuple[int, int]], optional): - The patch size. Defaults to 16. - in_c (int): - The number of input channels. Defaults to 3. - """ - super().__init__() - h, w = image_size if isinstance( - image_size, (tuple, list)) else (image_size, image_size) - ph, pw = patch_size if isinstance( - patch_size, (tuple, list)) else (patch_size, patch_size) - nh, nw = h // ph, w // pw - - self.cls_token = nn.Parameter(torch.Tensor(1, 1, d_model)) - self.pos_emb = nn.Parameter(torch.Tensor(1, nh*nw, d_model)) - nn.init.kaiming_uniform_(self.cls_token, a=math.sqrt(5)) - nn.init.kaiming_uniform_(self.pos_emb, a=math.sqrt(5)) - - self.tokenizer = nn.Conv2d( - in_c, d_model, (ph, pw), (ph, pw), bias=False) - self.encoder = nn.ModuleList([ - ImageEncoderLayer(d_model, *args, **kwargs) - for _ in range(num_layers) - ]) - - def forward(self, x: torch.Tensor, cls_token: torch.Tensor = None) -> torch.Tensor: - """ - Forward pass of the ImageEncoder. - """ - x = self.tokenizer(x) - x = x.flatten(2).transpose(1, 2) - x = x + self.pos_emb.expand(x.size(0), -1, -1) - if cls_token is None: - cls_token = self.cls_token.expand(x.size(0), -1, -1) - x = torch.cat((cls_token, x), dim=1) - att_weights = [] - for layer in self.encoder: - x, _att_weights = layer(x) - att_weights.append(_att_weights) - cls_token, hidden = torch.split(x, (1, x.size(1)-1), dim=1) - return cls_token.squeeze(1), hidden, att_weights diff --git a/chameleon/transformers/efficientformer.py b/chameleon/transformers/efficientformer.py deleted file mode 100644 index 2a80ff4..0000000 --- a/chameleon/transformers/efficientformer.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import List - -import torch -import torch.nn as nn -from transformers import EfficientFormerConfig, EfficientFormerModel - -from .utils import list_models_transformers - -__all__ = ['EfficientFormer'] - - -class EfficientFormer(nn.Module): - - def __init__( - self, - depths: List[int] = [3, 2, 6, 4], - hidden_sizes: List[int] = [48, 96, 224, 448], - downsamples: List[bool] = [True, True, True, True], - dim: int = 448, - key_dim: int = 32, - attention_ratio: int = 4, - resolution: int = 7, - num_hidden_layers: int = 5, - num_attention_heads: int = 8, - mlp_expansion_ratio: int = 4, - hidden_dropout_prob: float = 0.0, - patch_size: int = 16, - num_channels: int = 3, - pool_size: int = 3, - downsample_patch_size: int = 3, - downsample_stride: int = 2, - downsample_pad: int = 1, - drop_path_rate: float = 0.0, - num_meta3d_blocks: int = 1, - distillation: bool = True, - use_layer_scale: bool = True, - layer_scale_init_value: float = 1e-5, - hidden_act: str = "gelu", - initializer_range: float = 0.02, - layer_norm_eps: float = 1e-12, - **kwargs, - ) -> None: - r""" - This is the configuration class to store the configuration of an - [`EfficientFormerModel`]. It is used to instantiate an EfficientFormer - model according to the specified arguments, defining the model architecture. - - Instantiating a configuration with the defaults will yield a similar - configuration to that of the EfficientFormer [snap-research/efficientformer-l1] - (https://huggingface.co/snap-research/efficientformer-l1) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used - to control the model outputs. Read the documentation from [`PretrainedConfig`] - for more information. - - Args: - depths (`List(int)`, *optional*, defaults to `[3, 2, 6, 4]`) - Depth of each stage. - hidden_sizes (`List(int)`, *optional*, defaults to `[48, 96, 224, 448]`) - Dimensionality of each stage. - downsamples (`List(bool)`, *optional*, defaults to `[True, True, True, True]`) - Whether or not to downsample inputs between two stages. - dim (`int`, *optional*, defaults to 448): - Number of channels in Meta3D layers - key_dim (`int`, *optional*, defaults to 32): - The size of the key in meta3D block. - attention_ratio (`int`, *optional*, defaults to 4): - Ratio of the dimension of the query and value to the dimension - of the key in MSHA block - resolution (`int`, *optional*, defaults to 5) - Size of each patch - num_hidden_layers (`int`, *optional*, defaults to 5): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 8): - Number of attention heads for each attention layer in the 3D - MetaBlock. - mlp_expansion_ratio (`int`, *optional*, defaults to 4): - Ratio of size of the hidden dimensionality of an MLP to the - dimensionality of its input. - hidden_dropout_prob (`float`, *optional*, defaults to 0.1): - The dropout probability for all fully connected layers in the - embeddings and encoder. - patch_size (`int`, *optional*, defaults to 16): - The size (resolution) of each patch. - num_channels (`int`, *optional*, defaults to 3): - The number of input channels. - pool_size (`int`, *optional*, defaults to 3): - Kernel size of pooling layers. - downsample_patch_size (`int`, *optional*, defaults to 3): - The size of patches in downsampling layers. - downsample_stride (`int`, *optional*, defaults to 2): - The stride of convolution kernels in downsampling layers. - downsample_pad (`int`, *optional*, defaults to 1): - Padding in downsampling layers. - drop_path_rate (`int`, *optional*, defaults to 0): - Rate at which to increase dropout probability in DropPath. - num_meta3d_blocks (`int`, *optional*, defaults to 1): - The number of 3D MetaBlocks in the last stage. - distillation (`bool`, *optional*, defaults to `True`): - Whether to add a distillation head. - use_layer_scale (`bool`, *optional*, defaults to `True`): - Whether to scale outputs from token mixers. - layer_scale_init_value (`float`, *optional*, defaults to 1e-5): - Factor by which outputs from token mixers are scaled. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the - encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and - `"gelu_new"` are supported. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - layer_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the layer normalization layers. - """ - super().__init__() - self.config = EfficientFormerConfig( - depths=depths, - hidden_sizes=hidden_sizes, - downsamples=downsamples, - dim=dim, - key_dim=key_dim, - attention_ratio=attention_ratio, - resolution=resolution, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - mlp_expansion_ratio=mlp_expansion_ratio, - hidden_dropout_prob=hidden_dropout_prob, - patch_size=patch_size, - num_channels=num_channels, - pool_size=pool_size, - downsample_patch_size=downsample_patch_size, - downsample_stride=downsample_stride, - downsample_pad=downsample_pad, - drop_path_rate=drop_path_rate, - num_meta3d_blocks=num_meta3d_blocks, - distillation=distillation, - use_layer_scale=use_layer_scale, - layer_scale_init_value=layer_scale_init_value, - hidden_act=hidden_act, - initializer_range=initializer_range, - layer_norm_eps=layer_norm_eps, - **kwargs, - ) - model = EfficientFormerModel(self.config) - self.model = self._model_clip(model) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - _, all_hidden_state = self.model(x, output_hidden_states=True, return_dict=False) - all_hidden_state = [all_hidden_state[i] for i in [1, 3, 5, 6]] - return all_hidden_state - - @staticmethod - def list_models(author='snap-research', search='efficientformer') -> List[str]: - return list_models_transformers(author=author, search=search) - - @classmethod - def from_pretrained(cls, name, **kwargs) -> 'EfficientFormer': - model = cls(**kwargs) - _model = EfficientFormerModel.from_pretrained(name, **kwargs) - model.model = cls._model_clip(_model) - return model - - @staticmethod - def _model_clip(m) -> None: - - class _Identity(nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, **kwargs): - return x - - m.flat = nn.Identity() - m.meta3D_layers = nn.Identity() - m.layernorm = nn.Identity() - m.encoder.last_stage = _Identity() - return m diff --git a/chameleon/transformers/metaformer.py b/chameleon/transformers/metaformer.py deleted file mode 100644 index 1d03751..0000000 --- a/chameleon/transformers/metaformer.py +++ /dev/null @@ -1,361 +0,0 @@ -from typing import List, Optional, Union -from warnings import warn - -import torch -import torch.nn as nn - -from ..nn import CNN2Dcell, LayerNorm2d, PowerModule, StarReLU -from .token_mixer import (AttentionMixing, PoolMixing, RandomMixing, - SepConvMixing) - -__all__ = ['MetaFormer', 'MetaFormerBlock', 'MlpBlock'] - - -MODEL_SETTINGS = { - 'poolformer_v2_tiny': { - 'depths': [1, 1, 3, 1], - 'hidden_sizes': [16, 32, 96, 128], - 'token_mixers': 'PoolMixing', - 'mlp_forwards': {'name': 'MlpBlock', 'expand_ratio': 1.5} - }, - 'poolformer_v2_small': { - 'depths': [2, 2, 4, 2], - 'hidden_sizes': [32, 64, 128, 256], - 'token_mixers': 'PoolMixing', - 'mlp_forwards': {'name': 'MlpBlock', 'expand_ratio': 1.5} - }, - 'poolformer_v2_s12': { - 'depths': [2, 2, 6, 2], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': 'PoolMixing', - }, - 'poolformer_v2_s24': { - 'depths': [4, 4, 12, 4], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': 'PoolMixing', - }, - 'poolformer_v2_s36': { - 'depths': [6, 6, 18, 6], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': 'PoolMixing', - }, - 'poolformer_v2_m36': { - 'depths': [6, 6, 18, 6], - 'hidden_sizes': [96, 192, 384, 768], - 'token_mixers': 'PoolMixing', - }, - 'poolformer_v2_m48': { - 'depths': [8, 8, 24, 8], - 'hidden_sizes': [96, 192, 384, 768], - 'token_mixers': 'PoolMixing', - }, - 'convformer_s18': { - 'depths': [3, 3, 9, 3], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': 'SepConvMixing', - }, - 'convformer_s36': { - 'depths': [3, 12, 18, 3], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': 'SepConvMixing', - }, - 'convformer_m36': { - 'depths': [3, 12, 18, 3], - 'hidden_sizes': [96, 192, 384, 576], - 'token_mixers': 'SepConvMixing', - }, - 'convformer_b36': { - 'depths': [3, 12, 18, 3], - 'hidden_sizes': [128, 256, 512, 768], - 'token_mixers': 'SepConvMixing', - }, - 'caformer_tiny': { - 'depths': [1, 1, 2, 1], - 'hidden_sizes': [16, 32, 64, 128], - 'token_mixers': ['SepConvMixing', 'SepConvMixing', 'AttentionMixing', 'AttentionMixing'], - 'mlp_forwards': {'name': 'MlpBlock', 'expand_ratio': 1.5} - }, - 'caformer_small': { - 'depths': [1, 1, 4, 2], - 'hidden_sizes': [16, 48, 128, 160], - 'token_mixers': ['SepConvMixing', 'SepConvMixing', 'AttentionMixing', 'AttentionMixing'], - 'mlp_forwards': {'name': 'MlpBlock', 'expand_ratio': 1.5} - }, - 'caformer_s18': { - 'depths': [3, 3, 9, 3], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': ['SepConvMixing', 'SepConvMixing', 'AttentionMixing', 'AttentionMixing'], - }, - 'caformer_s36': { - 'depths': [3, 12, 18, 3], - 'hidden_sizes': [64, 128, 320, 512], - 'token_mixers': ['SepConvMixing', 'SepConvMixing', 'AttentionMixing', 'AttentionMixing'], - }, - 'caformer_m36': { - 'depths': [3, 12, 18, 3], - 'hidden_sizes': [96, 192, 384, 576], - 'token_mixers': ['SepConvMixing', 'SepConvMixing', 'AttentionMixing', 'AttentionMixing'], - }, - 'caformer_b36': { - 'depths': [3, 12, 18, 3], - 'hidden_sizes': [128, 256, 512, 768], - 'token_mixers': ['SepConvMixing', 'SepConvMixing', 'AttentionMixing', 'AttentionMixing'], - }, -} - - -def build_token_mixer(name, **options) -> Union[nn.Module, None]: - cls = globals().get(name, None) - if cls is None: - raise ValueError(f'Token mixer named {name} is not supported.') - return cls(**options) - - -def build_mlps_forward(name, **options) -> Union[nn.Module, None]: - cls = globals().get(name, None) - if cls is None: - raise ValueError(f'MLP forward named {name} is not supported.') - return cls(**options) - - -class MlpBlock(nn.Module): - - def __init__( - self, - in_features: int, - out_features: int, - expand_ratio: float = 4 - ) -> None: - """ - MLP as used in MetaFormer models baslines and related networks. - - Args: - in_features: - The number of input features. - out_features: - The number of output features. - expand_ratio: - The multiplier applied to the number of input features to obtain - the number of hidden features. Defaults to 4. - """ - super().__init__() - hidden_features = int(expand_ratio * in_features) - self.fc1_block = nn.Conv2d(in_features, hidden_features, 1) - self.fc2_block = nn.Conv2d(hidden_features, out_features, 1) - self.act = StarReLU() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc1_block(x) - x = self.act(x) - x = self.fc2_block(x) - return x - - -class MetaFormerBlock(nn.Module): - - def __init__( - self, - in_features: int, - token_mixer: Union[str, dict] = None, - mlp_forward: Union[str, dict] = None, - ) -> None: - """ - A single block of the MetaFormer model, consisting of a weighted sum of - a token mixing module and an MLP. - - Args: - in_features (int): - The number of input features. - token_mixer (Union[dict, nn.Module], optional): - The token mixing module to use in the block. Can be either an - nn.Module instance or a dictionary specifying the token mixing - module to build using the `build_token_mixer` function. - Defaults to None. - mlp_forward (Union[dict, nn.Module], optional): - The MLP module to use in the block. Can be either an nn.Module - instance or a dictionary specifying the MLP module to build using - the `build_mlps_forward` function. - Defaults to None. - """ - - super().__init__() - self.in_features = in_features - self.token_mixer = self._build_token_mixers(token_mixer) - self.mlp_forward = self._build_mlp_forwars(mlp_forward) - self.norm_mixer = LayerNorm2d(in_features) - self.norm_mlp = LayerNorm2d(in_features) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.token_mixer(self.norm_mixer(x)) - x = x + self.mlp_forward(self.norm_mlp(x)) - return x - - def _build_mlp_forwars(self, param: Union[str, dict]) -> nn.Module: - if param is None: - return nn.Identity() - if isinstance(param, str): - if param == 'Identity': - return nn.Identity() - elif param == 'MlpBlock': - return MlpBlock( - in_features=self.in_features, - out_features=self.in_features, - ) - else: - raise ValueError(f'Unsupport mlp_forwards settings: {param}') - elif isinstance(param, dict): - if param['name'] in ['MlpBlock']: - param.update({ - 'in_features': self.in_features, - 'out_features': self.in_features - }) - return build_mlps_forward(**param) - - def _build_token_mixers(self, param: Union[str, dict]) -> nn.Module: - if param is None: - return nn.Identity() - if isinstance(param, str): - if param == 'AttentionMixing': - return AttentionMixing(self.in_features) - elif param == 'SepConvMixing': - return SepConvMixing(self.in_features) - elif param == 'PoolMixing': - return PoolMixing() - elif param == 'RandomMixing': - warn( - 'Do not use RandomMixing in MetaFormer by pass string name,' - 'to token_mixer, use `token_mixer={"name": "RandomMixing", "num_tokens": N}` instead.' - 'Set token_mixer to nn.Identity() instead.' - ) - return nn.Identity() - elif param == 'Identity': - return nn.Identity() - else: - raise ValueError(f'Unsupport token mixer settings: {param}') - elif isinstance(param, dict): - if param['name'] in ['AttentionMixing', 'SepConvMixing']: - param.update({'in_features': self.in_features}) - return build_token_mixer(**param) - -class MetaFormer(PowerModule): - - def __init__( - self, - num_channels: int = 3, - depths: List[int] = [2, 2, 6, 2], - hidden_sizes: List[int] = [64, 128, 320, 512], - patch_sizes: List[int] = [7, 3, 3, 3], - strides: List[int] = [4, 2, 2, 2], - padding: List[int] = [2, 1, 1, 1], - token_mixers: Union[dict, str, List[Union[dict, str]]] = 'PoolMixing', - mlp_forwards: Union[dict, str, List[Union[dict, str]]] = 'MlpBlock', - out_indices: Optional[List[int]] = None, - ) -> None: - """ - Initializes the MetaFormer model. - - Args: - num_channels (int, optional): - The number of channels in the input image. Defaults to 3. - depths (List[int], optional): - The number of blocks in each stage of the MetaFormer. - Defaults to [2, 2, 6, 2]. - hidden_sizes (List[int], optional): - The number of channels in each stage of the MetaFormer. - Defaults to [64, 128, 320, 512]. - patch_sizes (List[int], optional): - The patch size used in each stage of the MetaFormer. - Defaults to [7, 3, 3, 3]. - strides (List[int], optional): - The stride used in each stage of the MetaFormer. - Defaults to [4, 2, 2, 2]. - padding (List[int], optional): - The padding used in each stage of the MetaFormer. - Defaults to [2, 1, 1, 1]. - token_mixers (Union[dict, str, List[Union[dict, str]]], optional): - The token mixing modules used in the model. - Defaults to 'PoolMixing'. - mlp_forwards (Union[dict, str, List[Union[dict, str]]], optional): - The MLP modules used in the model. - Defaults to 'MlpBlock'. - out_indices (Optional[List[int]], optional): - The indices of the output feature maps. - Defaults to None. - """ - super().__init__() - - if not isinstance(depths, (list, tuple)): - raise ValueError('depths must be either list or tuple.') - - if not isinstance(hidden_sizes, (list, tuple)): - raise ValueError('hidden_sizes must be either list or tuple.') - - self.num_stage = len(depths) - - self.downsamples = nn.ModuleList([ - nn.Sequential( - LayerNorm2d(hidden_sizes[i - 1]) if i > 0 else nn.Identity(), - nn.Conv2d( - in_channels=num_channels if i == 0 else hidden_sizes[i-1], - out_channels=hidden_sizes[i], - kernel_size=ksize, - stride=s, - padding=p - ), - LayerNorm2d(hidden_sizes[i]) if i == 0 else nn.Identity(), - ) for i, (ksize, s, p) in enumerate(zip(patch_sizes, strides, padding)) - ]) - - token_mixers = [token_mixers] * self.num_stage \ - if not isinstance(token_mixers, (list, tuple)) else token_mixers - mlp_forwards = [mlp_forwards] * self.num_stage \ - if not isinstance(mlp_forwards, (list, tuple)) else mlp_forwards - - self.stages = nn.ModuleList([ - nn.Sequential(*[ - MetaFormerBlock( - in_features=hidden_sizes[i], - token_mixer=token_mixers[i], - mlp_forward=mlp_forwards[i], - ) - for _ in range(depth) - ]) - for i, depth in enumerate(depths) - ]) - - self.out_indices = out_indices - self.initialize_weights_() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - outs = [] - for i in range(self.num_stage): - x = self.downsamples[i](x) - x = self.stages[i](x) - outs.append(x) - - if self.out_indices is not None: - outs = [outs[i] for i in self.out_indices] - - return outs - - @classmethod - def from_pretrained(cls, name: str, **kwargs) -> 'MetaFormer': - """ - Initializes the MetaFormer model from the pretrained model. - - Args: - model_name (str): - The name of the pretrained model. - **kwargs: - The other arguments of the model. - - Returns: - MetaFormer: - The MetaFormer model. - """ - if name not in MODEL_SETTINGS: - raise ValueError(f'Unsupport model name: {name}') - - model_settings = MODEL_SETTINGS[name] - model_settings.update(kwargs) - return cls(**model_settings) diff --git a/chameleon/transformers/mobilevit.py b/chameleon/transformers/mobilevit.py deleted file mode 100644 index c429655..0000000 --- a/chameleon/transformers/mobilevit.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import List - -import torch -import torch.nn as nn -from transformers import MobileViTConfig, MobileViTModel - -from .utils import list_models_transformers - -__all__ = ['MobileViT'] - - -class MobileViT(nn.Module): - - def __init__( - self, - num_channels: int = 3, - image_size: int = 256, - patch_size: int = 2, - hidden_sizes: List[int] = [144, 192, 240], - neck_hidden_sizes: List[int] = [16, 32, 64, 96, 128, 160, 640], - num_attention_heads: int = 4, - mlp_ratio: float = 2.0, - expand_ratio: float = 4.0, - hidden_act: str = "relu", - conv_kernel_size: int = 3, - output_stride: int = 32, - hidden_dropout_prob: float = 0.1, - attention_probs_dropout_prob: float = 0.0, - classifier_dropout_prob: float = 0.1, - initializer_range: float = 0.02, - layer_norm_eps: float = 1e-5, - qkv_bias: bool = True, - aspp_out_channels: int = 256, - atrous_rates: List[int] = [6, 12, 18], - aspp_dropout_prob: float = 0.1, - semantic_loss_ignore_index: int = 255, - **kwargs, - ) -> None: - """ - This is the configuration of a `MobileViTModel`. It is used to instantiate - a MobileViT model according to the specified arguments, defining the model - architecture. Instantiating a configuration with the defaults will yield - a similar configuration to that of the MobileViT architecture. - - [apple/mobilevit-small](https://huggingface.co/apple/mobilevit-small) - - Args: - num_channels (int, optional): - The number of input channels. Defaults to 3. - image_size (int, optional): - The size (resolution) of each image. Defaults to 256. - patch_size (int, optional): - The size (resolution) of each patch. Defaults to 2. - hidden_sizes (List[int], optional): - Dimensionality (hidden size) of the Transformer encoders at each - stage. Defaults to [144, 192, 240] - neck_hidden_sizes (List[int], optional): - The number of channels for the feature maps of the backbone. - Defaults to [16, 32, 64, 96, 128, 160, 640] - num_attention_heads (int, optional): - Number of attention heads for each attention layer in the - Transformer encoder. Defaults to 4 - mlp_ratio (float, optional): - The ratio of the number of channels in the output of the MLP to - the number of channels in the input. Defaults to 2.0 - expand_ratio (float, optional): - Expansion factor for the MobileNetv2 layers. Defaults to 4.0. - hidden_act (str or function, optional): - The non-linear activation function (function or string) in the - Transformer encoder and convolution layers. Defaults to "relu". - conv_kernel_size (int, optional): - The size of the convolutional kernel in the MobileViT layer. - Defaults to 3. - output_stride (int, optional): - The ratio of the spatial resolution of the output to the - resolution of the input image. Defaults to 32. - hidden_dropout_prob (float, optional): - The dropout probabilitiy for all fully connected layers in the - Transformer encoder. Defaults to 0.1. - attention_probs_dropout_prob (float, optional): - The dropout ratio for the attention probabilities. Defaults to 0.0 - classifier_dropout_prob (float, optional): - The dropout ratio for attached classifiers. Defaults to 0.1. - initializer_range (float, optional): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. Defaults to 0.02. - layer_norm_eps (float, optional): - The epsilon used by the layer normalization layers. - Defaults to 1e-5. - qkv_bias (bool, optional): - Whether to add a bias to the queries, keys and values. - Defaults to True. - aspp_out_channels (int, optional): - Number of output channels used in the ASPP layer for semantic - segmentation. Defaults to 256. - atrous_rates (List[int], optional): - Dilation (atrous) factors used in the ASPP layer for semantic - segmentation. Defaults to [6, 12, 18]. - aspp_dropout_prob (float, optional): - The dropout ratio for the ASPP layer for semantic segmentation. - Defaults to 0.1. - semantic_loss_ignore_index (int, optional): - The index that is ignored by the loss function of the semantic - segmentation model. Defaults to 255. - """ - super().__init__() - self.config = MobileViTConfig( - num_channels=num_channels, - image_size=image_size, - patch_size=patch_size, - hidden_sizes=hidden_sizes, - neck_hidden_sizes=neck_hidden_sizes, - num_attention_heads=num_attention_heads, - mlp_ratio=mlp_ratio, - expand_ratio=expand_ratio, - hidden_act=hidden_act, - conv_kernel_size=conv_kernel_size, - output_stride=output_stride, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - classifier_dropout_prob=classifier_dropout_prob, - initializer_range=initializer_range, - layer_norm_eps=layer_norm_eps, - qkv_bias=qkv_bias, - aspp_out_channels=aspp_out_channels, - atrous_rates=atrous_rates, - aspp_dropout_prob=aspp_dropout_prob, - semantic_loss_ignore_index=semantic_loss_ignore_index, - **kwargs, - ) - self.model = MobileViTModel(self.config, expand_output=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - _, all_hidden_state = self.model(x, output_hidden_states=True, return_dict=False) - return all_hidden_state - - @staticmethod - def list_models(author='apple', search='mobilevit') -> List[str]: - return list_models_transformers(author=author, search=search) - - @classmethod - def from_pretrained(cls, name, **kwargs) -> 'MobileViT': - model = cls(**kwargs) - kwargs.update({'expand_output': False}) - model.model = MobileViTModel.from_pretrained(name, **kwargs) - return model diff --git a/chameleon/transformers/poolformer.py b/chameleon/transformers/poolformer.py deleted file mode 100644 index ec9b321..0000000 --- a/chameleon/transformers/poolformer.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import List - -import torch -import torch.nn as nn -from transformers import PoolFormerConfig, PoolFormerModel - -from .utils import list_models_transformers - -__all__ = ["PoolFormer"] - - -class PoolFormer(nn.Module): - - def __init__( - self, - num_channels: int = 3, - patch_size: int = 16, - stride: int = 16, - pool_size: int = 3, - mlp_ratio: float = 4.0, - depths: List[int] = [2, 2, 6, 2], - hidden_sizes: List[int] = [64, 128, 320, 512], - patch_sizes: List[int] = [7, 3, 3, 3], - strides: List[int] = [4, 2, 2, 2], - padding: List[int] = [2, 1, 1, 1], - num_encoder_blocks: int = 4, - drop_path_rate: float = 0.0, - hidden_act: str = 'relu', - use_layer_scale: bool = True, - layer_scale_init_value: float = 1e-5, - initializer_range: float = 0.02, - **kwargs: dict, - ) -> None: - """ - PoolFormer is a model that replaces attention token mixer in transfomrers - with extremely simple operator, pooling. - - Transformers have shown great potential in computer vision tasks. A common - belief is their attention-based token mixer module contributes most to - their competence. However, recent works show the attention-based module - in transformers can be replaced by spatial MLPs and the resulted models - still perform quite well. Based on this observation, we hypothesize that - the general architecture of the transformers, instead of the specific - token mixer module, is more essential to the model's performance. - - To verify this, we deliberately replace the attention module in transformers - with an embarrassingly simple spatial pooling operator to conduct only - the most basic token mixing. Surprisingly, we observe that the derived - model, termed as PoolFormer, achieves competitive performance on multiple - computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves - 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like - baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer - parameters and 48%/60% fewer MACs. The effectiveness of PoolFormer - verifies our hypothesis and urges us to initiate the concept of "MetaFormer", - a general architecture abstracted from transformers without specifying - the token mixer. Based on the extensive experiments, we argue that - MetaFormer is the key player in achieving superior results for recent - transformer and MLP-like models on vision tasks. - - This work calls for more future research dedicated to improving MetaFormer - instead of focusing on the token mixer modules. Additionally, our proposed - PoolFormer could serve as a starting baseline for future MetaFormer - architecture design. - - Args: - num_channels (int, optional): - The number of channels in the input data. Defaults to 3. - patch_size (int, optional): - The size of the patches extracted from the input data. - Defaults to 16. - stride (int, optional): - The stride of the convolutional layer used to extract patches - from the input data. Defaults to 16. - pool_size (int, optional): - The size of the pooling kernel used in the PoolFormer encoder - layers. Defaults to 3. - mlp_ratio (float, optional): - The ratio of the hidden size in the feedforward layer of the - PoolFormer encoder to the input size. Defaults to 4.0. - depths (List[int], optional): - The number of blocks in each stage of the PoolFormer encoder. - Defaults to [2, 2, 6, 2]. - hidden_sizes (List[int], optional): - The size of the hidden layer in each block of the PoolFormer - encoder. Defaults to [64, 128, 320, 512]. - patch_sizes (List[int], optional): - The size of the convolutional kernel in each block of the - PoolFormer encoder. Defaults to [7, 3, 3, 3]. - strides (List[int], optional): - The stride of the convolutional layer in each block of the - PoolFormer encoder. Defaults to [4, 2, 2, 2]. - padding (List[int], optional): - The padding size of the convolutional layer in each block of the - PoolFormer encoder. Defaults to [2, 1, 1, 1]. - num_encoder_blocks (int, optional): - The number of encoder blocks in the PoolFormer encoder. - Defaults to 4. - drop_path_rate (float, optional): - The drop path rate used in the PoolFormer encoder. - Defaults to 0.0. - hidden_act (str, optional): - The activation function used in the PoolFormer encoder. - Defaults to "relu". - use_layer_scale (bool, optional): - Whether to use layer scaling in the PoolFormer encoder. - Defaults to True. - layer_scale_init_value (float, optional): - The initial value of the layer scale in the PoolFormer encoder. - Defaults to 1e-5. - initializer_range (float, optional): - The range of the uniform distribution used to initialize the - weights in the PoolFormer encoder. Defaults to 0.02. - """ - super().__init__() - self.config = PoolFormerConfig( - num_channels=num_channels, - patch_size=patch_size, - stride=stride, - pool_size=pool_size, - mlp_ratio=mlp_ratio, - depths=depths, - hidden_sizes=hidden_sizes, - patch_sizes=patch_sizes, - strides=strides, - padding=padding, - num_encoder_blocks=num_encoder_blocks, - drop_path_rate=drop_path_rate, - hidden_act=hidden_act, - use_layer_scale=use_layer_scale, - layer_scale_init_value=layer_scale_init_value, - initializer_range=initializer_range, - **kwargs, - ) - self.model = PoolFormerModel(self.config) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - *_, all_hidden_state = self.model(x, output_hidden_states=True, return_dict=False) - return all_hidden_state - - @staticmethod - def list_models(author='sail', search='poolformer') -> List[str]: - return list_models_transformers(author=author, search=search) - - @classmethod - def from_pretrained(cls, name, **kwargs) -> 'PoolFormer': - model = cls(**kwargs) - model.model = PoolFormerModel.from_pretrained(name, **kwargs) - return model diff --git a/chameleon/transformers/token_mixer.py b/chameleon/transformers/token_mixer.py deleted file mode 100644 index 744a89a..0000000 --- a/chameleon/transformers/token_mixer.py +++ /dev/null @@ -1,327 +0,0 @@ -from typing import Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ..nn.components import LayerNorm2d, build_activation -from ..nn.mbcnn import MBCNNcell - -__all__ = [ - 'Attention', 'AttentionMixing', 'RandomMixing', 'SepConvMixing', - 'PoolMixing', 'SelfAttention', -] - - -class SelfAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int = 8, - dropout: float = 0., - bias: bool = True, - ) -> None: - """ - Initialize the multi-head attention mechanism. - - Args: - embed_dim (int): - Dimensionality of the input and output feature vectors. - num_heads (int, optional): - Number of attention heads, defaults to 8. - dropout (float, optional): - Dropout rate, defaults to 0. - bias (bool, optional): - Whether to include bias in the projection layers, defaults to True. - """ - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * \ - num_heads == self.embed_dim, "embed_dim must be divisible by num_heads." - self.in_proj_query = nn.Linear(embed_dim, embed_dim, bias=bias) - self.in_proj_key = nn.Linear(embed_dim, embed_dim, bias=bias) - self.in_proj_value = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.dropout_layer = nn.Dropout(dropout) - - def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): - """ - Forward pass for the multi-head attention mechanism. - - Args: - query (Tensor): - Query tensor of shape (batch_size, seq_len, embed_dim). - key (Tensor): - Key tensor of shape (batch_size, seq_len, embed_dim). - value (Tensor): - Value tensor of shape (batch_size, seq_len, embed_dim). - attn_mask (Optional[Tensor]): - Mask to be added to attention scores before softmax. - Default: None. - key_padding_mask (Optional[Tensor]): - Mask indicating which elements in the key sequence should be ignored. - Default: None. - """ - Q = self.in_proj_query(query) - K = self.in_proj_key(key) - V = self.in_proj_value(value) - - # Split into multiple heads - Q = Q.view(Q.size(0), Q.size(1), self.num_heads, - self.head_dim).transpose(1, 2) - K = K.view(K.size(0), K.size(1), self.num_heads, - self.head_dim).transpose(1, 2) - V = V.view(V.size(0), V.size(1), self.num_heads, - self.head_dim).transpose(1, 2) - - # Scaled dot-product attention - attn_output_weights = torch.matmul( - Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5) - - # Apply the key padding mask - if key_padding_mask is not None: - attn_output_weights.masked_fill_( - key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')) - - if attn_mask is not None: - attn_output_weights += attn_mask - attn_output_weights = F.softmax(attn_output_weights, dim=-1) - attn_output_weights = self.dropout_layer(attn_output_weights) - - # Get final output - attn_output = torch.matmul(attn_output_weights, V) - attn_output = attn_output.transpose(1, 2).contiguous().view( - attn_output.size(0), -1, self.embed_dim) - - return self.out_proj(attn_output), attn_output_weights - - -class Attention(nn.Module): - - def __init__( - self, - in_features: int, - num_heads: int = 8, - qkv_bias: bool = True, - return_attn: bool = False, - add_output_layer: bool = True, - is_cross_attention: bool = False, - ) -> None: - """ - Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762. - Modified from timm. - - Args: - dim (int): - Number of input channels. - head_dim (int, optional): - Dimensionality of the output of each head, defaults to 32. - num_heads (int, optional): - Number of attention heads, defaults to None (uses `dim` divided by `head_dim`). - qkv_bias (bool, optional): - Whether to include bias in the projection layers, defaults to False. - return_attn (bool, optional): - Whether to return the attention map, defaults to False. - add_output_layer (bool, optional): - Whether to add an output layer, defaults to True. - is_cross_attention (bool, optional): - Whether this is cross-attention, defaults to False. - """ - super().__init__() - assert in_features % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = in_features // num_heads - self.scale = self.head_dim ** -0.5 - self.return_attn = return_attn - self.is_cross_attention = is_cross_attention - - self.proj = nn.Linear(in_features, in_features) \ - if add_output_layer else nn.Identity() - - if self.is_cross_attention: - self.q = nn.Linear(in_features, in_features, bias=qkv_bias) - self.kv = nn.Linear(in_features, in_features * 2, bias=qkv_bias) - else: - self.qkv = nn.Linear(in_features, in_features * 3, bias=qkv_bias) - - def forward(self, x: torch.Tensor, hidden_state: torch.Tensor = None) -> torch.Tensor: - """ - Applies self-attention to the input tensor. - - Args: - x: - Input tensor of shape (batch_size, seq_len, dim). - hidden_state: - Hidden state of the previous token, used for cross-attention. - - Returns: - A tuple containing the output tensor of shape (batch_size, seq_len, dim) and - the attention tensor of shape (batch_size, num_heads, seq_len, seq_len). - """ - B, N, C = x.shape - - if self.is_cross_attention: - q = self.q(x) - kv = self.kv(hidden_state) - k, v = torch.chunk(kv, 2, dim=-1) - else: - qkv = self.qkv(x) - q, k, v = torch.chunk(qkv, 3, dim=-1) - - q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) - k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) - v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) - - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - - x = attn @ v - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - - if self.return_attn: - return x, attn - - return x - - -class AttentionMixing(Attention): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Applies self-attention to the input tensor. - - Args: - x (torch.Tensor): - Input tensor of shape (batch_size, in_channels, height, width). - - Returns: - A tensor of the same shape as input after applying the self-attention. - """ - B, C, H, W = x.shape - x = x.reshape(B, C, H * W).permute(0, 2, 1) - x = super().forward(x) - if self.return_attn: - x, attn = x - x = x.permute(0, 2, 1).reshape(B, C, H, W) - if self.return_attn: - return x, attn - return x - - -class RandomMixing(nn.Module): - - def __init__(self, num_tokens: int): - """ Random mixing of tokens. - Args: - num_tokens (int): - Number of tokens. - """ - super().__init__() - self.random_matrix = nn.parameter.Parameter( - torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), - requires_grad=False) - - def forward(self, x): - """ - Applies random-attention to the input tensor. - - Args: - x (torch.Tensor): - Input tensor of shape (batch_size, in_channels, height, width). - - Returns: - A tensor of the same shape as input after applying the random-attention. - """ - B, C, H, W = x.shape - x = x.reshape(B, C, H * W) - x = torch.einsum('mn, bcn -> bcm', self.random_matrix, x) - x = x.reshape(B, C, H, W) - return x - - -class SepConvMixing(nn.Module): - - def __init__( - self, - in_features: int, - expand_ratio: float = 2, - kernel_size: Union[int, Tuple[int, int]] = 7, - inner_act: Union[dict, nn.Module] = {'name': 'StarReLU'}, - ) -> None: - """ - SepConvMixing is an inverted separable convolution block from MobileNetV2. - It performs a depthwise convolution followed by a pointwise convolution. - Ref: https://arxiv.org/abs/1801.04381. - - Args: - in_channels (int): - Number of input channels. - expand_ratio (float): - Expansion ratio of the hidden channels. Defaults to 2. - kernel_size (Union[int, Tuple[int, int]]): - Size of the depthwise convolution kernel. Defaults to 7. - inner_act (Union[dict, nn.Module]): - Activation function to be used internally. Defaults to StarReLU. - """ - super().__init__() - hid_channels = int(in_features * expand_ratio) - self.mbcnn_v2 = MBCNNcell( - in_channels=in_features, - out_channels=in_features, - hid_channels=hid_channels, - kernel=kernel_size, - norm=LayerNorm2d(in_features), - inner_norm=LayerNorm2d(hid_channels), - inner_act=inner_act if isinstance( - inner_act, nn.Module) else build_activation(**inner_act), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Perform a SepConvMixing operation on the input tensor. - - Args: - x (torch.Tensor): - Input tensor of shape (batch_size, in_channels, height, width). - - Returns: - A tensor of the same shape as the input after applying mbcnn_v2 module. - """ - return self.mbcnn_v2(x) - - -class PoolMixing(nn.Module): - - def __init__(self, pool_size: int = 3): - """ - Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418 - - Args: - pool_size (int): Size of the pooling window. - """ - super().__init__() - self.pool = nn.AvgPool2d( - pool_size, - stride=1, - padding=pool_size//2, - count_include_pad=False - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Apply pooling and subtract the result from the input. - - Args: - x (torch.Tensor): - Input tensor of shape (batch_size, in_channels, height, width). - - Returns: - A tensor of the same shape as input after applying the pooling and subtraction. - """ - return self.pool(x) - x diff --git a/chameleon/transformers/utils.py b/chameleon/transformers/utils.py deleted file mode 100644 index 21bf339..0000000 --- a/chameleon/transformers/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Tuple, Union - -from huggingface_hub import list_models - -__all__ = ['list_models_transformers', 'calculate_patch_size'] - - -def list_models_transformers(*args, **kwargs): - models = list(iter(list_models(*args, **kwargs))) - return [m.modelId for m in models] - - -def calculate_patch_size( - image_size: Union[int, Tuple[int, int]], - num_patches: Union[int, Tuple[int, int]], -) -> Tuple[int, int]: - ''' - Calculate the number of patches that can fit into an image. - - Args: - image_size (Union[int, Tuple[int, int]]): The size of the image. - num_patches (Union[int, Tuple[int, int]]): The number of the patch. - - Returns: - Tuple[int, int]: The number of patches that can fit into the image. - ''' - if isinstance(image_size, int): - image_size = (image_size, image_size) - if isinstance(num_patches, int): - num_patches = (num_patches, num_patches) - if image_size[0] % num_patches[0]: - raise ValueError( - f'`image_size` {image_size[0]} can not divided with `{num_patches[0]}`.') - if image_size[1] % num_patches[1]: - raise ValueError( - f'`image_size` {image_size[1]} can not divided with `{num_patches[1]}`.') - patch_size = ( - image_size[0] // num_patches[0], - image_size[1] // num_patches[1] - ) - return patch_size diff --git a/chameleon/transformers/vit.py b/chameleon/transformers/vit.py deleted file mode 100644 index faef4c6..0000000 --- a/chameleon/transformers/vit.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import List, Tuple, Union - -import torch -import torch.nn as nn -from transformers import ViTConfig, ViTModel - -from .utils import list_models_transformers - -__all__ = ['ViT'] - - -class ViT(nn.Module): - - def __init__( - self, - hidden_size: int = 768, - num_hidden_layers: int = 12, - num_attention_heads: int = 12, - intermediate_size: int = 3072, - hidden_act: str = 'relu', - hidden_dropout_prob: float = 0.0, - attention_probs_dropout_prob: float = 0.0, - initializer_range: float = 0.02, - layer_norm_eps: float = 1e-12, - image_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - num_channels: int = 3, - qkv_bias: bool = True, - encoder_stride: int = 16, - **kwargs, - ) -> None: - """ - ViT: Vision Transformer - A transformer model for image classification - - Args: - hidden_size (int, optional): - Dimensionality of the encoder layers and the pooler layer. - Default is 768. - num_hidden_layers (int, optional): - Number of hidden layers in the Transformer encoder. - Default is 12. - num_attention_heads (int, optional): - Number of attention heads for each attention layer in the - Transformer encoder. - Default is 12. - intermediate_size (int, optional): - Dimensionality of the "intermediate" (i.e., feed-forward) layer - in the Transformer encoder. - Default is 3072. - hidden_act (str, optional): - The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" - are supported. - Default is "relu". - hidden_dropout_prob (float, optional): - The dropout probability for all fully connected layers in the - embeddings, encoder, and pooler. - Default is 0.0. - attention_probs_dropout_prob (float, optional): - The dropout ratio for the attention probabilities. - Default is 0.0. - initializer_range (float, optional): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - Default is 0.02. - layer_norm_eps (float, optional): - The epsilon used by the layer normalization layers. - Default is 1e-12. - image_size (Union[int, Tuple[int, int]], optional): - The size (resolution) of each image. - Default is 224. - patch_size (Union[int, Tuple[int, int]], optional): - The size (resolution) of each patch. - Default is 16. - num_channels (int, optional): - The number of input channels. - Default is 3. - qkv_bias (bool, optional): - Whether to add a bias to the queries, keys and values. - Default is True. - encoder_stride (int, optional): - Factor to increase the spatial resolution by in the decoder head - for masked image modeling. - Default is 16. - """ - super().__init__() - self.config = ViTConfig( - hidden_size=hidden_size, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - hidden_act=hidden_act, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - initializer_range=initializer_range, - layer_norm_eps=layer_norm_eps, - image_size=image_size, - patch_size=patch_size, - num_channels=num_channels, - qkv_bias=qkv_bias, - encoder_stride=encoder_stride, - **kwargs, - ) - self.model = ViTModel(self.config, add_pooling_layer=False) - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - hidden_state = self.model(x).last_hidden_state - cls_token, hidden_state = torch.split(hidden_state, [1, hidden_state.shape[1] - 1], dim=1) - return cls_token.squeeze(dim=1), hidden_state - - @staticmethod - def list_models(author='google', search='vit') -> List[str]: - return list_models_transformers(author=author, search=search) - - @classmethod - def from_pretrained(cls, name, **kwargs) -> 'ViT': - model = cls(**kwargs) - kwargs.update({'add_pooling_layer': False}) - model.model = ViTModel.from_pretrained(name, **kwargs) - return model diff --git a/setup.cfg b/setup.cfg index a0dbf76..c12f2db 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,17 +25,12 @@ setup_requires= pip setuptools wheel + install_requires = timm>=0.5.4 - scikit-learn - transformers - torch>=2.4.0 + torch>=1.13 torchvision torchmetrics - albumentations - ptflops==0.7.0 - calflops - rich [options.packages.find] exclude = diff --git a/tests/base/blocks/test_conv_block.py b/tests/base/blocks/test_conv_block.py new file mode 100644 index 0000000..659328b --- /dev/null +++ b/tests/base/blocks/test_conv_block.py @@ -0,0 +1,137 @@ +import pytest +import torch +import torch.nn as nn + +from chameleon.base.blocks import Conv2dBlock, SeparableConv2dBlock + + +@pytest.fixture +def cnn_arch(): + return [ + {'in_channels': 3, 'out_channels': 32, 'kernel': 3}, + {'in_channels': 32, 'out_channels': 64, 'kernel': 3}, + {'in_channels': 64, 'out_channels': 128, 'kernel': 3}, + ] + + +@pytest.fixture +def fc_arch(): + return [ + {'in_channels': 3, 'out_channels': 32}, + {'in_channels': 32, 'out_channels': 64}, + {'in_channels': 64, 'out_channels': 128}, + ] + + +def test_SeparableConv2dBlock_forward(): + # Test input and output shapes + in_channels = 64 + out_channels = 128 + block = SeparableConv2dBlock(in_channels, out_channels) + x = torch.randn(1, in_channels, 64, 64) + output = block(x) + assert output.shape == (1, out_channels, 64, 64) + + # Test with different kernel size and padding + kernel_size = (5, 3) + padding = (1, 2) + block = SeparableConv2dBlock(in_channels, out_channels, kernel=kernel_size, padding=padding) + output = block(x) + assert output.shape == (1, out_channels, 62, 66) + + # Test with different stride + stride = 2 + block = SeparableConv2dBlock(in_channels, out_channels, stride=stride) + output = block(x) + assert output.shape == (1, out_channels, 32, 32) + + # Test with different output channels + out_channels = 32 + block = SeparableConv2dBlock(in_channels, out_channels) + output = block(x) + assert output.shape == (1, out_channels, 64, 64) + + # Test without normalization and activation + block = SeparableConv2dBlock(in_channels, out_channels, norm=None, act=None) + output = block(x) + assert output.shape == (1, out_channels, 64, 64) + + +def test_SeparableConv2dBlock_build_component(): + # Test build_component() function with different activation functions + activation_fns = [ + {'name': 'ReLU'}, + {'name': 'Sigmoid'}, + {'name': 'Tanh'}, + {'name': 'LeakyReLU', 'negative_slope': 0.2} + ] + for act in activation_fns: + block = SeparableConv2dBlock(64, 64, act=act) + assert isinstance(block.act, nn.Module) + + +def test_SeparableConv2dBlock_build_component(): + # Test build_component() function with different normalization layers + norm_layers = [ + {'name': 'BatchNorm2d', 'num_features': 64}, + {'name': 'InstanceNorm2d', 'num_features': 64}, + {'name': 'GroupNorm', 'num_groups': 8, 'num_channels': 64}, + ] + tgt_norms = [ + nn.BatchNorm2d, + nn.InstanceNorm2d, + nn.GroupNorm + ] + for norm, tgt in zip(norm_layers, tgt_norms): + block = SeparableConv2dBlock(64, 64, norm=norm) + assert isinstance(block.block['norm'], tgt) + + +@pytest.fixture +def input_tensor(): + return torch.randn((2, 3, 32, 32)) + + +@pytest.fixture +def output_shape(): + return (2, 16, 32, 32) + + +def test_Conv2dBlock_forward(input_tensor, output_shape): + model = Conv2dBlock(in_channels=3, out_channels=16) + output = model(input_tensor) + assert output.shape == output_shape + + +def test_Conv2dBlock_with_activation(input_tensor, output_shape): + model = Conv2dBlock(in_channels=3, out_channels=16, act={'name': 'ReLU', 'inplace': True}) + output = model(input_tensor) + assert output.shape == output_shape + assert torch.all(output >= 0) + + +def test_Conv2dBlock_with_batch_norm(input_tensor, output_shape): + model = Conv2dBlock(in_channels=3, out_channels=16, norm={'name': 'BatchNorm2d', 'num_features': 16}) + output = model(input_tensor) + assert output.shape == output_shape + assert torch.allclose(output.mean(dim=(0, 2, 3)), torch.zeros(16), rtol=1e-3, atol=1e-5) + assert torch.allclose(output.var(dim=(0, 2, 3)), torch.ones(16), rtol=1e-3, atol=1e-5) + + +def test_Conv2dBlock_init_type(input_tensor): + model = Conv2dBlock(in_channels=3, out_channels=16, init_type='uniform') + output1 = model(input_tensor) + model = Conv2dBlock(in_channels=3, out_channels=16, init_type='normal') + output2 = model(input_tensor) + assert not torch.allclose(output1, output2, rtol=1e-3, atol=1e-5) + + +def test_Conv2dBlock_all_together(input_tensor): + model = Conv2dBlock(in_channels=3, out_channels=16, + kernel=5, stride=2, padding=2, dilation=2, groups=1, + bias=True, padding_mode='reflect', + norm={'name': 'BatchNorm2d', 'num_features': 16, 'momentum': 0.5}, + act={'name': 'LeakyReLU', 'negative_slope': 0.1, 'inplace': True}, + init_type='uniform') + output = model(input_tensor) + assert output.shape == (2, 16, 14, 14) diff --git a/tests/nn/component/test_activation.py b/tests/base/components/test_activation.py similarity index 72% rename from tests/nn/component/test_activation.py rename to tests/base/components/test_activation.py index 3878119..24d62ed 100644 --- a/tests/nn/component/test_activation.py +++ b/tests/base/components/test_activation.py @@ -1,25 +1,26 @@ import pytest import torch +import torch.nn as nn -from chameleon.nn.components import SquaredReLU, StarReLU, build_activation +from chameleon.base.components import SquaredReLU, StarReLU, build_component -test_build_activation_data = [ - ('ReLU', torch.nn.ReLU), - ('LeakyReLU', torch.nn.LeakyReLU), - ('Swish', torch.nn.SiLU), +test_build_component_data = [ + ('ReLU', nn.ReLU), + ('LeakyReLU', nn.LeakyReLU), + ('Swish', nn.SiLU), ('StarReLU', StarReLU), ('SquaredReLU', SquaredReLU), ('FakeActivation', ValueError) ] -@pytest.mark.parametrize('name, expected_output', test_build_activation_data) -def test_build_activation(name, expected_output): +@pytest.mark.parametrize('name, expected_output', test_build_component_data) +def test_build_component(name, expected_output): if expected_output == ValueError: with pytest.raises(ValueError): - build_activation(name) + build_component(name) else: - assert isinstance(build_activation(name), expected_output) + assert isinstance(build_component(name), expected_output) def test_starrelu(): diff --git a/tests/nn/component/test_loss.py b/tests/base/components/test_loss.py similarity index 95% rename from tests/nn/component/test_loss.py rename to tests/base/components/test_loss.py index 20ab5d4..3d7a060 100644 --- a/tests/nn/component/test_loss.py +++ b/tests/base/components/test_loss.py @@ -1,7 +1,7 @@ import pytest import torch -from chameleon import AWingLoss, WeightedAWingLoss +from chameleon.base.components import AWingLoss, WeightedAWingLoss @pytest.fixture(scope='module') diff --git a/tests/nn/component/test_norm.py b/tests/base/components/test_norm.py similarity index 92% rename from tests/nn/component/test_norm.py rename to tests/base/components/test_norm.py index 9eb4dc1..803559f 100644 --- a/tests/nn/component/test_norm.py +++ b/tests/base/components/test_norm.py @@ -7,7 +7,7 @@ from torch.nn.modules.normalization import (CrossMapLRN2d, GroupNorm, LayerNorm, LocalResponseNorm) -from chameleon import LayerNorm2d, build_norm +from chameleon.base.components import LayerNorm2d, build_component NORM_CLASSES = { 'BatchNorm1d': BatchNorm1d, @@ -26,7 +26,7 @@ @pytest.mark.parametrize('name', NORM_CLASSES.keys()) -def test_build_norm(name: str) -> None: +def test_build_component(name: str) -> None: options = {} cls = NORM_CLASSES[name] if name.startswith('BatchNorm'): @@ -47,7 +47,7 @@ def test_build_norm(name: str) -> None: options['size'] = 3 elif name.startswith('CrossMapLRN2d'): options['size'] = 3 - norm = build_norm(name, **options) + norm = build_component(name, **options) assert isinstance(norm, cls) diff --git a/tests/nn/component/test_pool.py b/tests/base/components/test_pooling.py similarity index 90% rename from tests/nn/component/test_pool.py rename to tests/base/components/test_pooling.py index 3b01766..e2a4ece 100644 --- a/tests/nn/component/test_pool.py +++ b/tests/base/components/test_pooling.py @@ -1,7 +1,7 @@ import pytest import torch -from chameleon import build_pool +from chameleon import build_component @pytest.fixture @@ -22,7 +22,7 @@ def input_tensor(): @pytest.mark.parametrize('name, kwargs, expected_shape', pool_layers) def test_pool_layer(name, kwargs, expected_shape, input_tensor): # Build the pool layer - layer = build_pool(name, **kwargs) + layer = build_component(name, **kwargs) # Check the output shape output = layer(input_tensor) diff --git a/tests/nn/test_aspp.py b/tests/base/layers/test_aspp.py similarity index 75% rename from tests/nn/test_aspp.py rename to tests/base/layers/test_aspp.py index af22843..22959c8 100644 --- a/tests/nn/test_aspp.py +++ b/tests/base/layers/test_aspp.py @@ -1,7 +1,8 @@ import pytest import torch -from chameleon.nn import ASPPLayer, Hswish +from chameleon.base.components.activation import Hswish +from chameleon.base.layers import ASPP @pytest.fixture @@ -14,17 +15,17 @@ def test_aspp_layer(input_tensor): out_channels = 128 # Test default activation function (ReLU) - aspp_layer = ASPPLayer(in_channels, out_channels) + aspp_layer = ASPP(in_channels, out_channels) output = aspp_layer(input_tensor) assert output.size() == (1, out_channels, 32, 32) # Test with Hswish activation function - aspp_layer = ASPPLayer(in_channels, out_channels, output_activate=Hswish()) + aspp_layer = ASPP(in_channels, out_channels, output_activate=Hswish()) output = aspp_layer(input_tensor) assert output.size() == (1, out_channels, 32, 32) # Test with different dilation rates - aspp_layer = ASPPLayer(in_channels, out_channels) + aspp_layer = ASPP(in_channels, out_channels) aspp_layer.layers['DILATE1'].dilation = (2, 2) aspp_layer.layers['DILATE2'].dilation = (4, 4) aspp_layer.layers['DILATE3'].dilation = (8, 8) diff --git a/tests/nn/test_grl.py b/tests/base/layers/test_grl.py similarity index 94% rename from tests/nn/test_grl.py rename to tests/base/layers/test_grl.py index cc397e9..940ca15 100644 --- a/tests/nn/test_grl.py +++ b/tests/base/layers/test_grl.py @@ -1,6 +1,6 @@ import torch -from chameleon.nn import GradientReversalLayer +from chameleon.base.layers import GradientReversalLayer def test_gradient_reversal_layer(): diff --git a/tests/nn/test_selayer.py b/tests/base/layers/test_selayer.py similarity index 86% rename from tests/nn/test_selayer.py rename to tests/base/layers/test_selayer.py index 56e6901..cc8ced0 100644 --- a/tests/nn/test_selayer.py +++ b/tests/base/layers/test_selayer.py @@ -1,6 +1,6 @@ import torch -from chameleon.nn import SELayer +from chameleon.base.layers import SELayer def test_selayer_output_shape(): @@ -48,5 +48,5 @@ def test_selayer_reduction(): expected_channels = in_channels // reduction - assert se_layer.fc1.layer.cnn.out_channels == expected_channels - assert se_layer.fc2.layer.cnn.out_channels == in_channels + assert se_layer.fc1.block['conv'].out_channels == expected_channels + assert se_layer.fc2.block['conv'].out_channels == in_channels diff --git a/tests/nn/test_vae.py b/tests/base/layers/test_vae.py similarity index 95% rename from tests/nn/test_vae.py rename to tests/base/layers/test_vae.py index efd6a4e..3145948 100644 --- a/tests/nn/test_vae.py +++ b/tests/base/layers/test_vae.py @@ -1,7 +1,7 @@ import pytest import torch -from chameleon.nn import VAE +from chameleon.base.layers import VAE @pytest.fixture diff --git a/tests/nn/test_weightedsum.py b/tests/base/layers/test_weighted_sum.py similarity index 96% rename from tests/nn/test_weightedsum.py rename to tests/base/layers/test_weighted_sum.py index 60a49c0..25e1657 100644 --- a/tests/nn/test_weightedsum.py +++ b/tests/base/layers/test_weighted_sum.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from chameleon.nn import WeightedSum +from chameleon.base.layers import WeightedSum def test_weighted_sum_init(): diff --git a/tests/nn/test_positional_encoding.py b/tests/base/ops/test_positional_encoding.py similarity index 92% rename from tests/nn/test_positional_encoding.py rename to tests/base/ops/test_positional_encoding.py index c051345..cf7238e 100644 --- a/tests/nn/test_positional_encoding.py +++ b/tests/base/ops/test_positional_encoding.py @@ -1,7 +1,7 @@ import pytest import torch -from chameleon.nn import sinusoidal_positional_encoding_1d +from chameleon.base.ops import sinusoidal_positional_encoding_1d @pytest.mark.parametrize("length, dim", [(10, 16), (20, 32), (5, 8)]) diff --git a/tests/nn/test_PowerModule.py b/tests/base/test_power_module.py similarity index 82% rename from tests/nn/test_PowerModule.py rename to tests/base/test_power_module.py index ad45ea5..08f36a3 100644 --- a/tests/nn/test_PowerModule.py +++ b/tests/base/test_power_module.py @@ -1,7 +1,7 @@ import pytest import torch -from chameleon.nn.utils import PowerModule, initialize_weights +from chameleon.base.power_module import PowerModule class SimpleModel(PowerModule): @@ -16,12 +16,6 @@ def model(): return SimpleModel() -def test_initialize_weights(model): - initialize_weights(model) - for param in model.parameters(): - assert not torch.isnan(param).any() - - def test_freeze(model): model.freeze(verbose=True) for param in model.parameters(): diff --git a/tests/base/test_utils.py b/tests/base/test_utils.py new file mode 100644 index 0000000..fc94d6e --- /dev/null +++ b/tests/base/test_utils.py @@ -0,0 +1,44 @@ +import pytest +import torch +import torch.nn as nn + +from chameleon.base.utils import (has_children, initialize_weights_, + replace_module, replace_module_attr_value) + + +def test_has_children(): + model = nn.Sequential( + nn.Conv2d(3, 64, 3), + nn.ReLU(), + nn.Sequential( + nn.Conv2d(64, 64, 3), + nn.ReLU(), + ) + ) + assert has_children(model) + + +def test_replace_module(): + model = nn.Sequential( + nn.Conv2d(3, 64, 3), + nn.ReLU(), + ) + replace_module(model, nn.ReLU, nn.Sigmoid()) + assert model[1].__class__ == nn.Sigmoid + + +def test_replace_module_attr_value(): + model = nn.Sequential( + nn.Conv2d(3, 64, 3), + nn.ReLU(), + ) + replace_module_attr_value(model, nn.ReLU, 'inplace', True) + assert model[1].inplace == True + + +def test_initialize_weights_(): + model = nn.Conv2d(3, 64, 3) + model.weight.data.fill_(0) + initialize_weights_(model, 'normal') + for param in model.parameters(): + assert not torch.isnan(param).any() diff --git a/tests/efficientdet/test_efficientdet.py b/tests/efficientdet/test_efficientdet.py deleted file mode 100644 index 3041fec..0000000 --- a/tests/efficientdet/test_efficientdet.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -import torch - -from chameleon import EfficientDet - - -@pytest.fixture -def input_tensor(): - # create a sample input tensor - return torch.rand((1, 3, 512, 512)) - - -@pytest.mark.parametrize("compound_coef, pretrained", [ - (0, True), - (1, True), - (2, True), - (3, True), - (4, True), - (5, True), - (6, False), - (7, False), - (8, False), - (0, False), -]) -def test_efficientdet_backbone(input_tensor, compound_coef, pretrained): - # create the model with the specified compound_coef and pretrained options - model = EfficientDet(compound_coef=compound_coef, pretrained=pretrained) - - # verify that the model is PowerModule and nn.Module - assert isinstance(model, EfficientDet) - assert isinstance(model, torch.nn.Module) - - # verify that the forward pass of the model returns a list of feature maps - output = model(input_tensor) - assert isinstance(output, list) - - # verify that the shape of each feature map in the output list is correct - conv_channel_coef = { - 0: [40, 112, 320], - 1: [40, 112, 320], - 2: [48, 120, 352], - 3: [48, 136, 384], - 4: [56, 160, 448], - 5: [64, 176, 512], - 6: [72, 200, 576], - 7: [80, 224, 640], - 8: [88, 248, 704], - } - - for i in range(len(output)): - expected_shape = ( - 1, - model.fpn_num_filters[compound_coef], - int(input_tensor.shape[2] / 2 ** (i+3)), - int(input_tensor.shape[3] / 2 ** (i+3)) - ) - - assert output[i].shape == expected_shape diff --git a/tests/backbone/test_backbone.py b/tests/modules/backbone/test_backbone.py similarity index 100% rename from tests/backbone/test_backbone.py rename to tests/modules/backbone/test_backbone.py diff --git a/tests/neck/test_bifpn.py b/tests/modules/neck/test_bifpn.py similarity index 98% rename from tests/neck/test_bifpn.py rename to tests/modules/neck/test_bifpn.py index 259d3e1..738e682 100644 --- a/tests/neck/test_bifpn.py +++ b/tests/modules/neck/test_bifpn.py @@ -1,6 +1,6 @@ import torch -from chameleon.neck import BiFPN, BiFPNs +from chameleon.modules.necks import BiFPN, BiFPNs def test_bifpn(): diff --git a/tests/modules/neck/test_fpn.py b/tests/modules/neck/test_fpn.py new file mode 100644 index 0000000..1c8634e --- /dev/null +++ b/tests/modules/neck/test_fpn.py @@ -0,0 +1,82 @@ +import torch + +from chameleon.modules.necks import FPN + + +def test_fpn(): + in_channels_list = [256, 512, 1024, 2048] + out_channels = 256 + fpn = FPN(in_channels_list, out_channels, extra_layers=2, out_indices=[0, 1, 2, 3]) + x1 = torch.randn(1, in_channels_list[0], 128, 128) + x2 = torch.randn(1, in_channels_list[1], 64, 64) + x3 = torch.randn(1, in_channels_list[2], 32, 32) + x4 = torch.randn(1, in_channels_list[3], 16, 16) + feats = [x1, x2, x3, x4] + outs = fpn(feats) + assert len(outs) == 4 + assert fpn.conv1x1s[0].__class__.__name__ == 'Identity' + for out in outs: + assert out.shape[0] == 1 + assert out.shape[1] == out_channels + assert out.shape[2] == out.shape[3] + + +def test_build_fpn(): + in_channels_list = [256, 512, 1024, 2048] + out_channels = 256 + extra_layers = 2 + upsample_mode = 'bilinear' + out_indices = [0, 1, 2, 3] + fpn = FPN.build_fpn(in_channels_list, out_channels, extra_layers, out_indices, upsample_mode) + assert isinstance(fpn, FPN) + + +def test_build_dwfpn(): + in_channels_list = [256, 512, 1024, 2048] + out_channels = 256 + extra_layers = 2 + upsample_mode = 'bilinear' + out_indices = [0, 1, 2, 3] + fpn = FPN.build_dwfpn(in_channels_list, out_channels, extra_layers, out_indices, upsample_mode) + assert isinstance(fpn, FPN) + + +# def test_fpns_module(): +# # Define test inputs +# in_channels_list = [64, 128, 256] +# out_channels = 256 +# n_fpn = 3 +# extra_layers = 2 +# out_indices = [0, 2] +# upsample_mode = 'nearest' + +# # Initialize FPNs module +# fpns = FPNs( +# in_channels_list=in_channels_list, +# out_channels=out_channels, +# n_fpn=n_fpn, +# extra_layers=extra_layers, +# out_indices=out_indices, +# upsample_mode=upsample_mode, +# ) + +# # Generate test inputs +# input_shapes = [(1, in_channels, 32 // 2**i, 32 // 2**i) for i, in_channels in enumerate(in_channels_list)] +# inputs = [torch.randn(shape) for shape in input_shapes] + +# # Test forward pass +# output_shapes = [(1, out_channels, 32 // 2**i, 32 // 2**i) for i in range(len(in_channels_list))] +# output = fpns(inputs) +# assert isinstance(output, list) +# assert len(output) == 2 +# for i, idx in enumerate(out_indices): +# assert output[i].shape == output_shapes[idx] + +# # Test if the upsample_mode is correct +# for i in range(n_fpn): +# assert fpns.block[i].upsample_mode == upsample_mode + +# # Test if the input channels are correct +# for i in range(n_fpn): +# expected_in_channels = [out_channels] * (len(in_channels_list) + extra_layers) if i != 0 else in_channels_list +# assert fpns.block[i].in_channels_list == expected_in_channels diff --git a/tests/neck/test_neck.py b/tests/modules/neck/test_neck.py similarity index 94% rename from tests/neck/test_neck.py rename to tests/modules/neck/test_neck.py index bfb3b4e..e3157a7 100644 --- a/tests/neck/test_neck.py +++ b/tests/modules/neck/test_neck.py @@ -1,7 +1,7 @@ import pytest import torch -from chameleon.neck import build_neck, list_necks +from chameleon.modules import build_neck, list_necks INPUT1 = [ torch.rand(1, 16, 80, 80), @@ -56,7 +56,7 @@ def test_build_backbone(in_tensor, build_kwargs, expected): data = [ ( '', - ['fpn', 'fpns', 'bifpn', 'bifpns'] + ['fpn', 'bifpn', 'bifpns'] ), ( '*bi*', diff --git a/tests/neck/test_fpn.py b/tests/neck/test_fpn.py deleted file mode 100644 index 4cdf909..0000000 --- a/tests/neck/test_fpn.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch - -from chameleon.neck import FPN, FPNs - - -def test_fpn(): - in_channels_list = [256, 512, 1024, 2048] - out_channels = 256 - fpn = FPN(in_channels_list, out_channels, extra_layers=2, out_indices=[0, 1, 2, 3]) - x1 = torch.randn(1, in_channels_list[0], 128, 128) - x2 = torch.randn(1, in_channels_list[1], 64, 64) - x3 = torch.randn(1, in_channels_list[2], 32, 32) - x4 = torch.randn(1, in_channels_list[3], 16, 16) - feats = [x1, x2, x3, x4] - outs = fpn(feats) - assert len(outs) == 4 - assert fpn.conv1x1s[0].__class__.__name__ == 'Identity' - for out in outs: - assert out.shape[0] == 1 - assert out.shape[1] == out_channels - assert out.shape[2] == out.shape[3] - - -def test_build_fpn(): - in_channels_list = [256, 512, 1024, 2048] - out_channels = 256 - extra_layers = 2 - upsample_mode = 'bilinear' - out_indices = [0, 1, 2, 3] - fpn = FPN.build_fpn(in_channels_list, out_channels, extra_layers, out_indices, upsample_mode) - assert isinstance(fpn, FPN) - - -def test_build_dwfpn(): - in_channels_list = [256, 512, 1024, 2048] - out_channels = 256 - extra_layers = 2 - upsample_mode = 'bilinear' - out_indices = [0, 1, 2, 3] - fpn = FPN.build_dwfpn(in_channels_list, out_channels, extra_layers, out_indices, upsample_mode) - assert isinstance(fpn, FPN) - - -def test_fpns_module(): - # Define test inputs - in_channels_list = [64, 128, 256] - out_channels = 256 - n_fpn = 3 - extra_layers = 2 - out_indices = [0, 2] - upsample_mode = 'nearest' - - # Initialize FPNs module - fpns = FPNs( - in_channels_list=in_channels_list, - out_channels=out_channels, - n_fpn=n_fpn, - extra_layers=extra_layers, - out_indices=out_indices, - upsample_mode=upsample_mode, - ) - - # Generate test inputs - input_shapes = [(1, in_channels, 32 // 2**i, 32 // 2**i) for i, in_channels in enumerate(in_channels_list)] - inputs = [torch.randn(shape) for shape in input_shapes] - - # Test forward pass - output_shapes = [(1, out_channels, 32 // 2**i, 32 // 2**i) for i in range(len(in_channels_list))] - output = fpns(inputs) - assert isinstance(output, list) - assert len(output) == 2 - for i, idx in enumerate(out_indices): - assert output[i].shape == output_shapes[idx] - - # Test if the upsample_mode is correct - for i in range(n_fpn): - assert fpns.block[i].upsample_mode == upsample_mode - - # Test if the input channels are correct - for i in range(n_fpn): - expected_in_channels = [out_channels] * (len(in_channels_list) + extra_layers) if i != 0 else in_channels_list - assert fpns.block[i].in_channels_list == expected_in_channels diff --git a/tests/nn/test_block.py b/tests/nn/test_block.py deleted file mode 100644 index 1ed6d5b..0000000 --- a/tests/nn/test_block.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -from chameleon.nn import SeparableConvBlock - - -@pytest.fixture -def cnn_arch(): - return [ - {'in_channels': 3, 'out_channels': 32, 'kernel': 3}, - {'in_channels': 32, 'out_channels': 64, 'kernel': 3}, - {'in_channels': 64, 'out_channels': 128, 'kernel': 3}, - ] - - -@pytest.fixture -def fc_arch(): - return [ - {'in_channels': 3, 'out_channels': 32}, - {'in_channels': 32, 'out_channels': 64}, - {'in_channels': 64, 'out_channels': 128}, - ] - - -def test_SeparableConvBlock_forward(): - # Test input and output shapes - in_channels = 64 - out_channels = 128 - block = SeparableConvBlock(in_channels, out_channels) - x = torch.randn(1, in_channels, 64, 64) - output = block(x) - assert output.shape == (1, out_channels, 64, 64) - - # Test with different kernel size and padding - kernel_size = (5, 3) - padding = (1, 2) - block = SeparableConvBlock(in_channels, out_channels, kernel=kernel_size, padding=padding) - output = block(x) - assert output.shape == (1, out_channels, 62, 66) - - # Test with different stride - stride = 2 - block = SeparableConvBlock(in_channels, out_channels, stride=stride) - output = block(x) - assert output.shape == (1, out_channels, 32, 32) - - # Test with different output channels - out_channels = 32 - block = SeparableConvBlock(in_channels, out_channels) - output = block(x) - assert output.shape == (1, out_channels, 64, 64) - - # Test without normalization and activation - block = SeparableConvBlock(in_channels, out_channels, norm=None, act=None) - output = block(x) - assert output.shape == (1, out_channels, 64, 64) - - -def test_SeparableConvBlock_build_activation(): - # Test build_activation() function with different activation functions - activation_fns = [ - {'name': 'ReLU'}, - {'name': 'Sigmoid'}, - {'name': 'Tanh'}, - {'name': 'LeakyReLU', 'negative_slope': 0.2} - ] - for act in activation_fns: - block = SeparableConvBlock(64, 64, act=act) - assert isinstance(block.act, nn.Module) - - -def test_SeparableConvBlock_build_norm(): - # Test build_norm() function with different normalization layers - norm_layers = [ - {'name': 'BatchNorm2d', 'num_features': 64}, - {'name': 'InstanceNorm2d', 'num_features': 64}, - {'name': 'GroupNorm', 'num_groups': 8, 'num_channels': 64}, - ] - for norm in norm_layers: - block = SeparableConvBlock(64, 64, norm=norm) - assert isinstance(block.norm, nn.Module) diff --git a/tests/nn/test_cnn.py b/tests/nn/test_cnn.py deleted file mode 100644 index 67b8686..0000000 --- a/tests/nn/test_cnn.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -from chameleon.nn import CNN2Dcell - - -@pytest.fixture -def input_tensor(): - return torch.randn((2, 3, 32, 32)) - - -@pytest.fixture -def output_shape(): - return (2, 16, 32, 32) - - -def test_cnn2dcell_forward(input_tensor, output_shape): - model = CNN2Dcell(in_channels=3, out_channels=16) - output = model(input_tensor) - assert output.shape == output_shape - - -def test_cnn2dcell_with_activation(input_tensor, output_shape): - model = CNN2Dcell(in_channels=3, out_channels=16, act={'name': 'ReLU', 'inplace': True}) - output = model(input_tensor) - assert output.shape == output_shape - assert torch.all(output >= 0) - - -def test_cnn2dcell_with_batch_norm(input_tensor, output_shape): - model = CNN2Dcell(in_channels=3, out_channels=16, norm={'name': 'BatchNorm2d', 'num_features': 16}) - output = model(input_tensor) - assert output.shape == output_shape - assert torch.allclose(output.mean(dim=(0, 2, 3)), torch.zeros(16), rtol=1e-3, atol=1e-5) - assert torch.allclose(output.var(dim=(0, 2, 3)), torch.ones(16), rtol=1e-3, atol=1e-5) - - -def test_cnn2dcell_with_dropout(input_tensor, output_shape): - model = CNN2Dcell(in_channels=3, out_channels=16, dropout={'name': 'Dropout2d', 'p': 0.5}) - output = model(input_tensor) - assert output.shape == output_shape - - -def test_cnn2dcell_with_pooling(input_tensor): - model = CNN2Dcell(in_channels=3, out_channels=16, pool=nn.AdaptiveAvgPool2d(1)) - output = model(input_tensor) - assert output.shape == (2, 16, 1, 1) - - -def test_cnn2dcell_init_type(input_tensor): - model = CNN2Dcell(in_channels=3, out_channels=16, init_type='uniform') - output1 = model(input_tensor) - model = CNN2Dcell(in_channels=3, out_channels=16, init_type='normal') - output2 = model(input_tensor) - assert not torch.allclose(output1, output2, rtol=1e-3, atol=1e-5) - - -def test_cnn2dcell_all_together(input_tensor): - model = CNN2Dcell(in_channels=3, out_channels=16, - kernel=5, stride=2, padding=2, dilation=2, groups=1, - bias=True, padding_mode='reflect', - norm={'name': 'BatchNorm2d', 'num_features': 16, 'momentum': 0.5}, - dropout={'name': 'Dropout2d', 'p': 0.5}, - act={'name': 'LeakyReLU', 'negative_slope': 0.1, 'inplace': True}, - pool=nn.AdaptiveAvgPool2d(1), - init_type='uniform') - output = model(input_tensor) - assert output.shape == (2, 16, 1, 1) diff --git a/tests/nn/test_mbcnn.py b/tests/nn/test_mbcnn.py deleted file mode 100644 index d66772e..0000000 --- a/tests/nn/test_mbcnn.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -import torch.nn as nn - -from chameleon.nn import MBCNNcell - - -def test_mbcnncell_identity(): - # Test identity block - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell(16, 16, kernel=3, stride=1) - out = block(x) - assert out.shape == x.shape - - -def test_mbcnncell_expdim(): - # Test expansion block - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell(16, 32, kernel=3, stride=1) - out = block(x) - assert out.shape == (1, 32, 32, 32) - - -def test_mbcnncell_norm(): - # Test block with normalization layer - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell(16, 16, kernel=3, stride=1, norm=nn.BatchNorm2d(16)) - out = block(x) - assert out.shape == x.shape - - -def test_mbcnncell_se(): - # Test block with Squeeze-and-Excitation layer - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell(16, 16, kernel=3, stride=1, use_se=True) - out = block(x) - assert out.shape == x.shape - - -def test_mbcnncell_build_mbv1block(): - # Test building of MobileNetV1-style block - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell.build_mbv1block(16, 32) - out = block(x) - assert out.shape == (1, 32, 32, 32) - - -def test_mbcnncell_build_mbv2block(): - # Test building of MobileNetV2-style block - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell.build_mbv2block(16, 32) - out = block(x) - assert out.shape == (1, 32, 32, 32) - - -def test_mbcnncell_build_mbv3block(): - # Test building of MobileNetV3-style block - x = torch.randn(1, 16, 32, 32) - block = MBCNNcell.build_mbv3block(16, 32) - out = block(x) - assert out.shape == (1, 32, 32, 32) diff --git a/tests/tools/test_calflops.py b/tests/tools/test_calflops.py new file mode 100644 index 0000000..a10e6df --- /dev/null +++ b/tests/tools/test_calflops.py @@ -0,0 +1,12 @@ +import pytest + +from chameleon.modules import build_backbone +from chameleon.tools import calculate_flops + + +def test_calcualte_flops(): + model = build_backbone('resnet50') + flops, macs, params = calculate_flops(model, (1, 3, 224, 224)) + assert flops == '8.21 GFLOPS' + assert macs == '4.09 GMACs' + assert params == '25.56 M' diff --git a/tests/transformers/testMetaFormer.py b/tests/transformers/testMetaFormer.py deleted file mode 100644 index a23f5ba..0000000 --- a/tests/transformers/testMetaFormer.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -import torch -from torch import nn - -from chameleon.transformers.metaformer import MetaFormer, MetaFormerBlock - - -def test_init(): - model = MetaFormer() - assert isinstance(model, nn.Module) - assert isinstance(model, MetaFormer) - - -def test_forward(): - model = MetaFormer() - input_tensor = torch.rand(3, 3, 224, 224) - all_hidden_state = model(input_tensor) - assert all_hidden_state[0].shape == torch.Size([3, 64, 56, 56]) - assert all_hidden_state[1].shape == torch.Size([3, 128, 28, 28]) - assert all_hidden_state[2].shape == torch.Size([3, 320, 14, 14]) - assert all_hidden_state[3].shape == torch.Size([3, 512, 7, 7]) - - -def test_token_mixer(): - model = MetaFormer(token_mixers=[ - {'name': 'AttentionMixing', 'in_features': 64}, - {'name': 'PoolMixing', 'pool_size': 5}, - {'name': 'RandomMixing', 'num_tokens': 196}, - {'name': 'SepConvMixing', 'in_features': 512, 'expand_ratio': 4} - ]) - input_tensor = torch.rand(3, 3, 224, 224) - all_hidden_state = model(input_tensor) - assert all_hidden_state[0].shape == torch.Size([3, 64, 56, 56]) - assert all_hidden_state[1].shape == torch.Size([3, 128, 28, 28]) - assert all_hidden_state[2].shape == torch.Size([3, 320, 14, 14]) - assert all_hidden_state[3].shape == torch.Size([3, 512, 7, 7]) - - -@pytest.fixture -def input_tensor(): - return torch.randn(2, 3, 16, 16) - - -@pytest.fixture -def metaformer_block(): - return MetaFormerBlock(3) - - -def test_metaformer_block_forward(metaformer_block, input_tensor): - output = metaformer_block(input_tensor) - assert output.shape == input_tensor.shape diff --git a/tests/transformers/testMobileViT.py b/tests/transformers/testMobileViT.py deleted file mode 100644 index afb137c..0000000 --- a/tests/transformers/testMobileViT.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from torch import nn -from transformers import MobileViTConfig, MobileViTModel - -from chameleon import MobileViT - - -def test_init(): - model = MobileViT() - assert isinstance(model, nn.Module) - assert isinstance(model.config, MobileViTConfig) - assert isinstance(model.model, MobileViTModel) - - -def test_forward(): - model = MobileViT() - input_tensor = torch.rand(1, 3, 224, 224) - all_hidden_state = model(input_tensor) - assert all_hidden_state[0].shape == torch.Size([1, 32, 112, 112]) - assert all_hidden_state[1].shape == torch.Size([1, 64, 56, 56]) - assert all_hidden_state[2].shape == torch.Size([1, 96, 28, 28]) - assert all_hidden_state[3].shape == torch.Size([1, 128, 14, 14]) - assert all_hidden_state[4].shape == torch.Size([1, 160, 7, 7]) - - -def test_list_pretrained_models(): - models = MobileViT.list_models() - assert isinstance(models, list) - assert len(models) > 0 - - -def test_from_pretrained(): - model = MobileViT.from_pretrained('apple/mobilevit-small') - input_tensor = torch.rand(1, 3, 224, 224) - all_hidden_state = model(input_tensor) - assert all_hidden_state[0].shape == torch.Size([1, 32, 112, 112]) - assert all_hidden_state[1].shape == torch.Size([1, 64, 56, 56]) - assert all_hidden_state[2].shape == torch.Size([1, 96, 28, 28]) - assert all_hidden_state[3].shape == torch.Size([1, 128, 14, 14]) - assert all_hidden_state[4].shape == torch.Size([1, 160, 7, 7]) diff --git a/tests/transformers/testPoolFormer.py b/tests/transformers/testPoolFormer.py deleted file mode 100644 index fb72169..0000000 --- a/tests/transformers/testPoolFormer.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -from torch import nn -from transformers import PoolFormerConfig, PoolFormerModel - -from chameleon import PoolFormer - - -def test_init(): - model = PoolFormer() - assert isinstance(model, nn.Module) - assert isinstance(model.config, PoolFormerConfig) - assert isinstance(model.model, PoolFormerModel) - - -def test_forward(): - model = PoolFormer() - input_tensor = torch.rand(1, 3, 224, 224) - all_hidden_state = model(input_tensor) - assert all_hidden_state[0].shape == torch.Size([1, 64, 56, 56]) - assert all_hidden_state[1].shape == torch.Size([1, 128, 28, 28]) - assert all_hidden_state[2].shape == torch.Size([1, 320, 14, 14]) - assert all_hidden_state[3].shape == torch.Size([1, 512, 7, 7]) - - -def test_list_pretrained_models(): - models = PoolFormer.list_models() - assert isinstance(models, list) - assert len(models) > 0 - - -def test_from_pretrained(): - model = PoolFormer.from_pretrained('sail/poolformer_s12') - input_tensor = torch.rand(1, 3, 224, 224) - all_hidden_state = model(input_tensor) - assert all_hidden_state[0].shape == torch.Size([1, 64, 56, 56]) - assert all_hidden_state[1].shape == torch.Size([1, 128, 28, 28]) - assert all_hidden_state[2].shape == torch.Size([1, 320, 14, 14]) - assert all_hidden_state[3].shape == torch.Size([1, 512, 7, 7]) diff --git a/tests/transformers/testViT.py b/tests/transformers/testViT.py deleted file mode 100644 index 4a6fbfd..0000000 --- a/tests/transformers/testViT.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from torch import nn -from transformers import ViTConfig, ViTModel - -from chameleon import ViT - - -def test_init(): - model = ViT() - assert isinstance(model, nn.Module) - assert isinstance(model.config, ViTConfig) - assert isinstance(model.model, ViTModel) - - -def test_forward(): - model = ViT() - input_tensor = torch.rand(1, 3, 224, 224) - cls_token, hidden_state = model(input_tensor) - assert cls_token.shape == torch.Size([1, 768]) - assert hidden_state.shape == torch.Size([1, 196, 768]) - - -def test_list_pretrained_models(): - models = ViT.list_models() - assert isinstance(models, list) - assert len(models) > 0 - - -def test_from_pretrained(): - model = ViT.from_pretrained('google/vit-base-patch16-224') - input_tensor = torch.rand(1, 3, 224, 224) - cls_token, hidden_state = model(input_tensor) - assert cls_token.shape == torch.Size([1, 768]) - assert hidden_state.shape == torch.Size([1, 196, 768]) diff --git a/tests/transformers/test_build_transformers.py b/tests/transformers/test_build_transformers.py deleted file mode 100644 index fe6d7f5..0000000 --- a/tests/transformers/test_build_transformers.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest -import torch - -from chameleon import build_transformer, list_transformer -from chameleon.transformers import (BASE_TRANSFORMER_NAMES, TRANSFORMER, - EfficientFormer, MetaFormer, MobileViT, - PoolFormer, ViT) - - -def test_list_transformer(): - models = list_transformer() - assert len(models) == len(TRANSFORMER) - - -@pytest.mark.parametrize("model_name", BASE_TRANSFORMER_NAMES.keys()) -def test_build_transformer(model_name): - model = build_transformer(model_name) - assert isinstance(model, (ViT, MobileViT, PoolFormer, MetaFormer, EfficientFormer)) diff --git a/tests/transformers/test_utils_in_trans.py b/tests/transformers/test_utils_in_trans.py deleted file mode 100644 index 115af28..0000000 --- a/tests/transformers/test_utils_in_trans.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from chameleon import calculate_patch_size - - -def test_calculate_patch_size(): - - # Test case 1 - image_size = (256, 256) - num_patches = (4, 4) - expected_patch_size = (64, 64) - assert calculate_patch_size(image_size, num_patches) == expected_patch_size - - # Test case 2 - image_size = (512, 512) - num_patches = (8, 8) - expected_patch_size = (64, 64) - - assert calculate_patch_size(image_size, num_patches) == expected_patch_size - - # Test case 3 - invalid input - image_size = (512, 512) - num_patches = (7, 7) - - with pytest.raises(ValueError): - calculate_patch_size(image_size, num_patches)