1 change: 1 addition & 0 deletions predict.sh
@@ -0,0 +1 @@
python predict_segmentation.py --image_dir="./dataset/test" --checkpoint_dir="./checkpoint"
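For reference, a sketch of an equivalent invocation that spells out every flag predict_segmentation.py accepts after this change. The values shown are simply the defaults declared in the diff below, with the same image_dir and checkpoint_dir placeholders as predict.sh, so this should behave the same as the one-liner above:

python predict_segmentation.py \
    --image_dir="./dataset/test" \
    --checkpoint_dir="./checkpoint" \
    --num_classes=12 \
    --batch_size=10 \
    --num_initial_blocks=1 \
    --stage_two_repeat=2 \
    --skip_connections=False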
73 changes: 51 additions & 22 deletions predict_segmentation.py
@@ -7,15 +7,48 @@
import numpy as np
slim = tf.contrib.slim

image_dir = './dataset/test/'
images_list = sorted([os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.png')])

checkpoint_dir = "./checkpoint_mfb"
#==============INPUT ARGUMENTS==================
flags = tf.app.flags

#Parameters
flags.DEFINE_integer('num_classes', 12, 'The number of classes to predict.')
flags.DEFINE_integer('batch_size', 10, 'The batch_size for training.')

#Architectural changes
flags.DEFINE_integer('num_initial_blocks', 1, 'The number of initial blocks to use in ENet.')
flags.DEFINE_integer('stage_two_repeat', 2, 'The number of times to repeat stage two.')
flags.DEFINE_boolean('skip_connections', False, 'If True, perform skip connections from encoder to decoder.')

#Directory arguments
flags.DEFINE_string('checkpoint_dir', './checkpoint_mfb', 'The checkpoint directory to restore your model.')
flags.DEFINE_string('image_dir', './dataset/test', 'The image directory to find the images to predict.')

FLAGS = flags.FLAGS

#==========NAME HANDLING FOR CONVENIENCE==============
num_classes = FLAGS.num_classes
batch_size = FLAGS.batch_size

#Architectural changes
num_initial_blocks = FLAGS.num_initial_blocks
skip_connections = FLAGS.skip_connections
stage_two_repeat = FLAGS.stage_two_repeat

#Directories
image_dir = FLAGS.image_dir
checkpoint_dir = FLAGS.checkpoint_dir

#Directories
#Get the list of images to predict in image_dir
images_list = sorted([os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.png')])
#Get latest checkpoint in checkpoint_dir
checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
#Create the photo directory
photo_dir = checkpoint_dir + "/test_images"
if not os.path.exists(photo_dir):
os.mkdir(photo_dir)

num_initial_blocks = 1
skip_connections = False
stage_two_repeat = 2
'''
#Labels to colours are obtained from here:
https://github.com/alexgkendall/SegNet-Tutorial/blob/c922cc4a4fcc7ce279dd998fb2d4a8703f34ebd7/Scripts/test_segmentation_camvid.py
@@ -51,15 +84,10 @@
10: [0,128,192],
11: [0,0,0]}

#Create the photo directory
photo_dir = checkpoint_dir + "/test_images"
if not os.path.exists(photo_dir):
os.mkdir(photo_dir)

#Create a function to convert each pixel label to colour.
def grayscale_to_colour(image):
print 'Converting image...'
image = image.reshape((360, 480, 1))
image = image.reshape((image.shape[0], image.shape[1], 1))
image = np.repeat(image, 3, axis=-1)
for i in xrange(image.shape[0]):
for j in xrange(image.shape[1]):
@@ -77,13 +105,13 @@ def grayscale_to_colour(image):
# image = tf.image.resize_image_with_crop_or_pad(image, 360, 480)
# image = tf.cast(image, tf.float32)
image = preprocess(image)
images = tf.train.batch([image], batch_size = 10, allow_smaller_final_batch=True)
images = tf.train.batch([image], batch_size = batch_size, allow_smaller_final_batch=True)

#Create the model inference
with slim.arg_scope(ENet_arg_scope()):
logits, probabilities = ENet(images,
num_classes=12,
batch_size=10,
num_classes=num_classes,
batch_size=batch_size,
is_training=True,
reuse=None,
num_initial_blocks=num_initial_blocks,
@@ -100,21 +128,22 @@ def restore_fn(sess):
print 'HERE', predictions.get_shape()

sv = tf.train.Supervisor(logdir=None, init_fn=restore_fn)

with sv.managed_session() as sess:

for i in xrange(len(images_list) / 10 + 1):
for i in xrange(len(images_list) / batch_size + 1):
segmentations = sess.run(predictions)
# print segmentations.shape

for j in xrange(segmentations.shape[0]):
#Stop at the 233rd image as it's repeated
if i*10 + j == 223:
current_image = i * batch_size + j

if current_image > len(images_list) - 1:
break

converted_image = grayscale_to_colour(segmentations[j])
print 'Saving image %s/%s' %(i*10 + j, len(images_list))
print 'Saving image %s/%s' %(current_image, len(images_list) - 1)
plt.axis('off')
plt.imshow(converted_image)
imsave(photo_dir + "/image_%s.png" %(i*10 + j), converted_image)
# plt.show()
imsave(photo_dir + "/image_%s.png" %(current_image), converted_image)
# plt.show()
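For clarity, a minimal standalone sketch (plain Python, no TensorFlow, with a hypothetical images_list and batch_size) of the indexing scheme the new loop relies on: iterate over len(images_list) / batch_size + 1 batches and stop once current_image runs past the last real image, instead of hard-coding a stop at image 223 as before:

# Standalone sketch of the batch indexing introduced above; names are stand-ins.
images_list = ['image_%03d.png' % n for n in range(23)]   # pretend there are 23 test images
batch_size = 10

for i in range(len(images_list) // batch_size + 1):
    # the real script pulls a full batch of predictions via sess.run(predictions),
    # so the last batch can overshoot the number of remaining images
    for j in range(batch_size):
        current_image = i * batch_size + j
        if current_image > len(images_list) - 1:
            break   # same guard as in the diff: stop once every real image is handled
        print('Saving image %s/%s' % (current_image, len(images_list) - 1))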