enet.py — 14 changes: 8 additions & 6 deletions
@@ -111,11 +111,11 @@ def unpool(updates, mask, k_size=[1, 2, 2, 1], output_shape=None, scope=''):
     return ret

 @slim.add_arg_scope
-def initial_block(inputs, is_training=True, scope='initial_block'):
+def initial_block(inputs, num_channels = 3, is_training=True, scope='initial_block'):
     '''
     The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.
 
-    The conv branch has 13 layers, while the maxpool branch gives 3 layers corresponding to the RGB channels.
+    The conv branch has (16 - num_channels) layers, while the maxpool branch gives layers equal to the number of channels.
     Both output layers are then concatenated to give an output of 16 layers.
 
     NOTE: Does not need to store pooling indices since it won't be used later for the final upsampling.
@@ -127,7 +127,7 @@ def initial_block(inputs, is_training=True, scope='initial_block'):
     - net_concatenated(Tensor): a 4D Tensor that contains the
     '''
     #Convolutional branch
-    net_conv = slim.conv2d(inputs, 13, [3,3], stride=2, activation_fn=None, scope=scope+'_conv')
+    net_conv = slim.conv2d(inputs, 16 - num_channels, [3,3], stride=2, activation_fn=None, scope=scope+'_conv')
     net_conv = slim.batch_norm(net_conv, is_training=is_training, fused=True, scope=scope+'_batchnorm')
     net_conv = prelu(net_conv, scope=scope+'_prelu')
 
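As a sanity check on the channel arithmetic in this hunk, here is a minimal shape-only sketch (plain Python, not part of the PR; the helper name is made up) showing why the concatenated output stays at 16 feature maps for any num_channels:

def initial_block_output_shape(h, w, num_channels=3, total=16):
    # Both branches downsample by 2 (stride-2 3x3 conv, stride-2 2x2 max-pool).
    conv_maps = total - num_channels  # conv branch: 16 - num_channels filters
    pool_maps = num_channels          # pool branch: passes input channels through
    # Concatenating along the channel axis restores `total` feature maps.
    return (h // 2, w // 2, conv_maps + pool_maps)

# The old hard-coded RGB case and a grayscale input both end up with 16 maps:
assert initial_block_output_shape(512, 512, num_channels=3) == (256, 256, 16)
assert initial_block_output_shape(512, 512, num_channels=1) == (256, 256, 16)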
@@ -387,6 +387,7 @@ def bottleneck(inputs,
 def ENet(inputs,
          num_classes,
          batch_size,
+         num_channels=3,
          num_initial_blocks=1,
          stage_two_repeat=2,
          skip_connections=True,
@@ -421,8 +422,9 @@ def ENet(inputs,
              slim.arg_scope([slim.batch_norm], fused=True), \
              slim.arg_scope([slim.conv2d, slim.conv2d_transpose], activation_fn=None):
             #=================INITIAL BLOCK=================
-            for i in xrange(1, max(num_initial_blocks, 1) + 1):
-                net = initial_block(inputs, scope='initial_block_' + str(i))
+            net = initial_block(inputs, scope='initial_block_1')
+            for i in xrange(2, max(num_initial_blocks, 1) + 1):
+                net = initial_block(net, num_channels = num_channels, scope='initial_block_' + str(i))
 
             #Save for skip connection later
             if skip_connections:
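The loop rewrite is the functional fix here: the old loop passed inputs to every initial block, so with num_initial_blocks > 1 each iteration discarded the previous block's output. A toy stand-in (shape bookkeeping only, not the PR's TensorFlow code; names are illustrative) tracing the new control flow:

def initial_block_shape(shape, num_channels=3, total=16):
    # Stand-in for initial_block: halves H and W, emits `total` feature maps.
    h, w, _ = shape
    return (h // 2, w // 2, (total - num_channels) + num_channels)

def stacked_shape(input_shape, num_initial_blocks=1, num_channels=3):
    # The first block consumes the raw inputs once...
    net = initial_block_shape(input_shape)
    # ...and each later block consumes the previous output `net`, which is
    # exactly what the old loop (always re-reading `inputs`) failed to do.
    for _ in range(2, max(num_initial_blocks, 1) + 1):
        net = initial_block_shape(net, num_channels=num_channels)
    return net

print(stacked_shape((512, 512, 3), num_initial_blocks=2))  # -> (128, 128, 16)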
@@ -512,4 +514,4 @@ def ENet_arg_scope(weight_decay=2e-4,
         with slim.arg_scope([slim.batch_norm],
                             decay=batch_norm_decay,
                             epsilon=batch_norm_epsilon) as scope:
-            return scope
\ No newline at end of file
+            return scope