59 changes: 32 additions & 27 deletions chameleon/base/blocks/conv_block.py
@@ -58,9 +58,7 @@ def __init__(

        bias = False if norm is not None else bias

-        self.block = nn.ModuleDict()
-
-        self.block['dw_conv'] = nn.Conv2d(
+        self.dw_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel,
@@ -69,7 +67,7 @@ def __init__(
            groups=in_channels,
            bias=False,
        )
-        self.block['pw_conv'] = nn.Conv2d(
+        self.pw_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
@@ -78,18 +76,22 @@ def __init__(
            bias=bias,
        )
        if inner_norm is not None:
-            self.block['inner_norm'] = COMPONENTS.build(inner_norm) if isinstance(inner_norm, dict) else inner_norm
+            self.inner_norm = COMPONENTS.build(inner_norm) if isinstance(inner_norm, dict) else inner_norm
        if inner_act is not None:
-            self.block['inner_act'] = COMPONENTS.build(inner_act) if isinstance(inner_act, dict) else inner_act
+            self.inner_act = COMPONENTS.build(inner_act) if isinstance(inner_act, dict) else inner_act
        if norm is not None:
-            self.block['norm'] = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
+            self.norm = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
        if act is not None:
-            self.block['act'] = COMPONENTS.build(act) if isinstance(act, dict) else act
+            self.act = COMPONENTS.build(act) if isinstance(act, dict) else act
        self.initialize_weights_(init_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for _, m in self.block.items():
-            x = m(x)
+        x = self.dw_conv(x)
+        x = self.inner_norm(x) if hasattr(self, 'inner_norm') else x
+        x = self.inner_act(x) if hasattr(self, 'inner_act') else x
+        x = self.pw_conv(x)
+        x = self.norm(x) if hasattr(self, 'norm') else x
+        x = self.act(x) if hasattr(self, 'act') else x
        return x


@@ -151,25 +153,27 @@ def __init__(
                Options = {'normal', 'uniform'}

        Examples for using norm, act, and pool:
-            1. conv_block = Conv2dBlock(in_channels=3,
-                                        out_channels=12,
-                                        norm=nn.BatchNorm2d(12),
-                                        act=nn.ReLU(),
-                                        pool=nn.AdaptiveAvgPool2d(1))
-            2. conv_block = Conv2dBlock(in_channels=3,
-                                        out_channels=12,
-                                        norm={'name': 'BatchNorm2d', 'num_features': 12},
-                                        act={'name': 'ReLU', 'inplace': True})
+            1. conv_block = Conv2dBlock(
+                in_channels=3,
+                out_channels=12,
+                norm=nn.BatchNorm2d(12),
+                act=nn.ReLU(),
+                pool=nn.AdaptiveAvgPool2d(1)
+            )
+            2. conv_block = Conv2dBlock(
+                in_channels=3,
+                out_channels=12,
+                norm={'name': 'BatchNorm2d', 'num_features': 12},
+                act={'name': 'ReLU', 'inplace': True},
+            )

        Attributes:
-            block (nn.ModuleDict): a model block.
+            block (nn.Module): a model block.
        """
        super().__init__()
-        self.block = nn.ModuleDict()

        bias = False if norm is not None else bias

-        self.block['conv'] = nn.Conv2d(
+        self.conv = nn.Conv2d(
            int(in_channels),
            int(out_channels),
            kernel_size=kernel,
@@ -181,13 +185,14 @@ def __init__(
            padding_mode=padding_mode,
        )
        if norm is not None:
-            self.block['norm'] = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
+            self.norm = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
        if act is not None:
-            self.block['act'] = COMPONENTS.build(act) if isinstance(act, dict) else act
+            self.act = COMPONENTS.build(act) if isinstance(act, dict) else act

        self.initialize_weights_(init_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for _, m in self.block.items():
-            x = m(x)
+        x = self.conv(x)
+        x = self.norm(x) if hasattr(self, 'norm') else x
+        x = self.act(x) if hasattr(self, 'act') else x
        return x
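
Reviewer note: the core change above swaps the `nn.ModuleDict` container for plain attributes, with `hasattr` guards in `forward` for the optional stages. A minimal, self-contained sketch of that pattern (illustrative only: `TinyConvBlock` is hypothetical and the `COMPONENTS` registry is omitted):

```python
import torch
import torch.nn as nn


class TinyConvBlock(nn.Module):
    """Sketch of the direct-attribute pattern used in this PR."""

    def __init__(self, in_channels, out_channels, norm=None, act=None):
        super().__init__()
        # Submodules assigned as plain attributes are registered by
        # nn.Module, so they appear in repr() and parameters(), and
        # callers can reach them directly as block.conv / block.norm.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        if norm is not None:
            self.norm = norm
        if act is not None:
            self.act = act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        # Optional stages are skipped when the attribute was never set.
        x = self.norm(x) if hasattr(self, 'norm') else x
        x = self.act(x) if hasattr(self, 'act') else x
        return x


block = TinyConvBlock(3, 16, norm=nn.BatchNorm2d(16), act=nn.ReLU())
out = block(torch.randn(2, 3, 32, 32))
assert out.shape == (2, 16, 32, 32)
assert isinstance(block.norm, nn.BatchNorm2d)  # dotted access, no dict lookup
```

One practical consequence worth noting: state_dict keys lose the `block.` prefix (e.g. `block.dw_conv.weight` becomes `dw_conv.weight`), so checkpoints saved before this change would need their keys remapped.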
17 changes: 8 additions & 9 deletions chameleon/base/layers/aspp.py
@@ -19,10 +19,10 @@ class ASPP(PowerModule):

    ARCHS = {
        # ksize, stride, padding, dilation, is_use_hs
-        'DILATE1': [3, 1, 1, 1, True],
-        'DILATE2': [3, 1, 2, 2, True],
-        'DILATE3': [3, 1, 4, 4, True],
-        'DILATE4': [3, 1, 8, 8, True],
+        'dilate_layer1': [3, 1, 1, 1, True],
+        'dilate_layer2': [3, 1, 2, 2, True],
+        'dilate_layer3': [3, 1, 4, 4, True],
+        'dilate_layer4': [3, 1, 8, 8, True],
    }

    def __init__(
@@ -43,8 +43,7 @@ def __init__(
                Activation function for the output layer. Defaults to ReLU.
        """
        super().__init__()
-        self.layers = nn.ModuleDict()
-        for dilate_name, cfg in self.ARCHS.items():
+        for name, cfg in self.ARCHS.items():
            ksize, stride, padding, dilation, use_hs = cfg
            layer = BLOCKS.build(
                {
@@ -59,12 +58,12 @@ def __init__(
                    'act': COMPONENTS.build({'name': 'Hswish' if use_hs else 'ReLU'}),
                }
            )
-            self.layers[dilate_name] = layer
+            self.add_module(name, layer)

        self.output_layer = BLOCKS.build(
            {
                'name': 'Conv2dBlock',
-                'in_channels': in_channels * len(self.layers),
+                'in_channels': in_channels * len(self.ARCHS),
                'out_channels': out_channels,
                'kernel': 1,
                'stride': 1,
@@ -75,7 +74,7 @@
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        outputs = [layer(x) for layer in self.layers.values()]
+        outputs = [getattr(self, name)(x) for name in self.ARCHS.keys()]
        outputs = torch.cat(outputs, dim=1)
        outputs = self.output_layer(outputs)
        return outputs
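
Reviewer note: `add_module(name, layer)` registers a dynamically named child so it is tracked like any attribute assignment and reachable with dotted access, which the updated test below relies on (`aspp_layer.dilate_layer1`). A minimal sketch under assumed names (`TinyASPP` is hypothetical):

```python
import torch
import torch.nn as nn


class TinyASPP(nn.Module):
    """Sketch: dynamically named branches registered via add_module()."""

    DILATIONS = {'dilate_layer1': 1, 'dilate_layer2': 2,
                 'dilate_layer3': 4, 'dilate_layer4': 8}

    def __init__(self, channels: int):
        super().__init__()
        for name, d in self.DILATIONS.items():
            # add_module() registers the child under `name`, so it is
            # visible to parameters()/state_dict() and reachable as an
            # attribute: self.dilate_layer1, self.dilate_layer2, ...
            self.add_module(name, nn.Conv2d(channels, channels, 3,
                                            padding=d, dilation=d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # getattr(self, name) retrieves each registered branch by name.
        outs = [getattr(self, name)(x) for name in self.DILATIONS]
        return torch.cat(outs, dim=1)


aspp = TinyASPP(8)
y = aspp(torch.randn(1, 8, 32, 32))
assert y.shape == (1, 32, 32, 32)                 # 4 branches concatenated
assert isinstance(aspp.dilate_layer3, nn.Conv2d)  # dotted access works
```

Switching the output layer's `in_channels` to `in_channels * len(self.ARCHS)` also drops the dependency on the deleted `self.layers` dict while keeping the same value, since one branch is built per `ARCHS` entry.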
16 changes: 9 additions & 7 deletions tests/base/blocks/test_conv_block.py
@@ -84,7 +84,7 @@ def test_SeparableConv2dBlock_build_component():
    ]
    for norm, tgt in zip(norm_layers, tgt_norms):
        block = SeparableConv2dBlock(64, 64, norm=norm)
-        assert isinstance(block.block['norm'], tgt)
+        assert isinstance(block.norm, tgt)


@pytest.fixture
@@ -127,11 +127,13 @@ def test_Conv2dBlock_init_type(input_tensor):


def test_Conv2dBlock_all_together(input_tensor):
-    model = Conv2dBlock(in_channels=3, out_channels=16,
-                        kernel=5, stride=2, padding=2, dilation=2, groups=1,
-                        bias=True, padding_mode='reflect',
-                        norm={'name': 'BatchNorm2d', 'num_features': 16, 'momentum': 0.5},
-                        act={'name': 'LeakyReLU', 'negative_slope': 0.1, 'inplace': True},
-                        init_type='uniform')
+    model = Conv2dBlock(
+        in_channels=3, out_channels=16,
+        kernel=5, stride=2, padding=2, dilation=2, groups=1,
+        bias=True, padding_mode='reflect',
+        norm={'name': 'BatchNorm2d', 'num_features': 16, 'momentum': 0.5},
+        act={'name': 'LeakyReLU', 'negative_slope': 0.1, 'inplace': True},
+        init_type='uniform'
+    )
    output = model(input_tensor)
    assert output.shape == (2, 16, 14, 14)
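
Reviewer note: the dict-style `norm`/`act` configs in this test are resolved by `COMPONENTS.build(...)` inside the block (see conv_block.py above). The registry implementation is not part of this diff; a plausible sketch, assuming `'name'` selects the class and the remaining keys become constructor kwargs (`build_component` is a hypothetical stand-in):

```python
import torch.nn as nn


def build_component(cfg: dict) -> nn.Module:
    """Hypothetical stand-in for COMPONENTS.build."""
    cfg = dict(cfg)                      # copy so the caller's config is untouched
    cls = getattr(nn, cfg.pop('name'))   # e.g. 'BatchNorm2d' -> nn.BatchNorm2d
    return cls(**cfg)                    # remaining keys are keyword arguments


norm = build_component({'name': 'BatchNorm2d', 'num_features': 16, 'momentum': 0.5})
act = build_component({'name': 'LeakyReLU', 'negative_slope': 0.1, 'inplace': True})
assert isinstance(norm, nn.BatchNorm2d) and isinstance(act, nn.LeakyReLU)
```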
8 changes: 4 additions & 4 deletions tests/base/layers/test_aspp.py
@@ -25,9 +25,9 @@ def test_aspp_layer(input_tensor):

    # Test with different dilation rates
    aspp_layer = ASPP(in_channels, out_channels)
-    aspp_layer.layers['DILATE1'].dilation = (2, 2)
-    aspp_layer.layers['DILATE2'].dilation = (4, 4)
-    aspp_layer.layers['DILATE3'].dilation = (8, 8)
-    aspp_layer.layers['DILATE4'].dilation = (16, 16)
+    aspp_layer.dilate_layer1.dilation = (2, 2)
+    aspp_layer.dilate_layer2.dilation = (4, 4)
+    aspp_layer.dilate_layer3.dilation = (8, 8)
+    aspp_layer.dilate_layer4.dilation = (16, 16)
    output = aspp_layer(input_tensor)
    assert output.size() == (1, out_channels, 32, 32)
4 changes: 2 additions & 2 deletions tests/base/layers/test_selayer.py
@@ -48,5 +48,5 @@ def test_selayer_reduction():

    expected_channels = in_channels // reduction

-    assert se_layer.fc1.block['conv'].out_channels == expected_channels
-    assert se_layer.fc2.block['conv'].out_channels == in_channels
+    assert se_layer.fc1.conv.out_channels == expected_channels
+    assert se_layer.fc2.conv.out_channels == in_channels