diff --git a/enet.py b/enet.py
index fe4cc2a..74e4108 100644
--- a/enet.py
+++ b/enet.py
@@ -111,11 +111,11 @@ def unpool(updates, mask, k_size=[1, 2, 2, 1], output_shape=None, scope=''):
     return ret
 
 @slim.add_arg_scope
-def initial_block(inputs, is_training=True, scope='initial_block'):
+def initial_block(inputs, num_channels=3, is_training=True, scope='initial_block'):
     '''
     The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.
 
-    The conv branch has 13 layers, while the maxpool branch gives 3 layers corresponding to the RGB channels.
+    The conv branch outputs (16 - num_channels) feature maps, while the maxpool branch passes through the num_channels input channels.
     Both output layers are then concatenated to give an output of 16 layers.
 
     NOTE: Does not need to store pooling indices since it won't be used later for the final upsampling.
@@ -127,7 +127,7 @@ def initial_block(inputs, is_training=True, scope='initial_block'):
     - net_concatenated(Tensor): a 4D Tensor that contains the
     '''
     #Convolutional branch
-    net_conv = slim.conv2d(inputs, 13, [3,3], stride=2, activation_fn=None, scope=scope+'_conv')
+    net_conv = slim.conv2d(inputs, 16 - num_channels, [3,3], stride=2, activation_fn=None, scope=scope+'_conv')
     net_conv = slim.batch_norm(net_conv, is_training=is_training, fused=True, scope=scope+'_batchnorm')
     net_conv = prelu(net_conv, scope=scope+'_prelu')
 
@@ -387,6 +387,7 @@ def bottleneck(inputs,
 def ENet(inputs,
          num_classes,
          batch_size,
+         num_channels=3,
          num_initial_blocks=1,
          stage_two_repeat=2,
          skip_connections=True,
@@ -421,8 +422,9 @@ def ENet(inputs,
          slim.arg_scope([slim.batch_norm], fused=True), \
          slim.arg_scope([slim.conv2d, slim.conv2d_transpose], activation_fn=None):
         #=================INITIAL BLOCK=================
-        for i in xrange(1, max(num_initial_blocks, 1) + 1):
-            net = initial_block(inputs, scope='initial_block_' + str(i))
+        net = initial_block(inputs, num_channels=num_channels, scope='initial_block_1')
+        for i in xrange(2, max(num_initial_blocks, 1) + 1):
+            net = initial_block(net, scope='initial_block_' + str(i))
 
         #Save for skip connection later
         if skip_connections:
@@ -512,4 +514,4 @@ def ENet_arg_scope(weight_decay=2e-4,
     with slim.arg_scope([slim.batch_norm],
                         decay=batch_norm_decay,
                         epsilon=batch_norm_epsilon) as scope:
-        return scope
\ No newline at end of file
+        return scope
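
Below is a minimal usage sketch of the new num_channels argument for a single-channel (grayscale) input. It assumes a TensorFlow 1.x / tf.contrib.slim environment and that ENet still returns (logits, probabilities) as in the upstream enet.py; the input resolution, class count, and batch size are placeholders, not values taken from this patch.

import tensorflow as tf
from enet import ENet, ENet_arg_scope

slim = tf.contrib.slim

# Hypothetical settings (not from this patch): CamVid-like resolution,
# 12 classes, batch of 8, single-channel (grayscale) images.
batch_size = 8
inputs = tf.placeholder(tf.float32, shape=[batch_size, 360, 480, 1])

with slim.arg_scope(ENet_arg_scope()):
    # num_channels must match the last dimension of `inputs` so the initial
    # block still concatenates to 16 feature maps:
    # (16 - num_channels) conv maps + num_channels maxpooled input channels.
    logits, probabilities = ENet(inputs,
                                 num_classes=12,
                                 batch_size=batch_size,
                                 num_channels=1,
                                 is_training=True)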