diff --git a/predict.sh b/predict.sh new file mode 100755 index 0000000..cd52527 --- /dev/null +++ b/predict.sh @@ -0,0 +1 @@ +python predict_segmentation.py --image_dir="./dataset/test" --checkpoint_dir="./checkpoint" diff --git a/predict_segmentation.py b/predict_segmentation.py index cb3de21..93f3ab7 100644 --- a/predict_segmentation.py +++ b/predict_segmentation.py @@ -7,15 +7,48 @@ import numpy as np slim = tf.contrib.slim -image_dir = './dataset/test/' -images_list = sorted([os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.png')]) -checkpoint_dir = "./checkpoint_mfb" +#==============INPUT ARGUMENTS================== +flags = tf.app.flags + +#Parameters +flags.DEFINE_integer('num_classes', 12, 'The number of classes to predict.') +flags.DEFINE_integer('batch_size', 10, 'The batch_size for training.') + +#Architectural changes +flags.DEFINE_integer('num_initial_blocks', 1, 'The number of initial blocks to use in ENet.') +flags.DEFINE_integer('stage_two_repeat', 2, 'The number of times to repeat stage two.') +flags.DEFINE_boolean('skip_connections', False, 'If True, perform skip connections from encoder to decoder.') + +#Directory arguments +flags.DEFINE_string('checkpoint_dir', './checkpoint_mfb', 'The checkpoint directory to restore your model.') +flags.DEFINE_string('image_dir', './dataset/test', 'The image directory to find the images to predict.') + +FLAGS = flags.FLAGS + +#==========NAME HANDLING FOR CONVENIENCE============== +num_classes = FLAGS.num_classes +batch_size = FLAGS.batch_size + +#Architectural changes +num_initial_blocks = FLAGS.num_initial_blocks +skip_connections = FLAGS.skip_connections +stage_two_repeat = FLAGS.stage_two_repeat + +#Directories +image_dir = FLAGS.image_dir +checkpoint_dir = FLAGS.checkpoint_dir + +#Directories +#Get the list of images to predict in image_dir +images_list = sorted([os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.png')]) +#Get latest checkpoint in checkpoint_dir checkpoint = tf.train.latest_checkpoint(checkpoint_dir) +#Create the photo directory +photo_dir = checkpoint_dir + "/test_images" +if not os.path.exists(photo_dir): + os.mkdir(photo_dir) -num_initial_blocks = 1 -skip_connections = False -stage_two_repeat = 2 ''' #Labels to colours are obtained from here: https://github.com/alexgkendall/SegNet-Tutorial/blob/c922cc4a4fcc7ce279dd998fb2d4a8703f34ebd7/Scripts/test_segmentation_camvid.py @@ -51,15 +84,10 @@ 10: [0,128,192], 11: [0,0,0]} -#Create the photo directory -photo_dir = checkpoint_dir + "/test_images" -if not os.path.exists(photo_dir): - os.mkdir(photo_dir) - #Create a function to convert each pixel label to colour. def grayscale_to_colour(image): print 'Converting image...' - image = image.reshape((360, 480, 1)) + image = image.reshape((image.shape[0], image.shape[1], 1)) image = np.repeat(image, 3, axis=-1) for i in xrange(image.shape[0]): for j in xrange(image.shape[1]): @@ -77,13 +105,13 @@ def grayscale_to_colour(image): # image = tf.image.resize_image_with_crop_or_pad(image, 360, 480) # image = tf.cast(image, tf.float32) image = preprocess(image) - images = tf.train.batch([image], batch_size = 10, allow_smaller_final_batch=True) + images = tf.train.batch([image], batch_size = batch_size, allow_smaller_final_batch=True) #Create the model inference with slim.arg_scope(ENet_arg_scope()): logits, probabilities = ENet(images, - num_classes=12, - batch_size=10, + num_classes=num_classes, + batch_size=batch_size, is_training=True, reuse=None, num_initial_blocks=num_initial_blocks, @@ -100,21 +128,22 @@ def restore_fn(sess): print 'HERE', predictions.get_shape() sv = tf.train.Supervisor(logdir=None, init_fn=restore_fn) - + with sv.managed_session() as sess: - for i in xrange(len(images_list) / 10 + 1): + for i in xrange(len(images_list) / batch_size + 1): segmentations = sess.run(predictions) # print segmentations.shape for j in xrange(segmentations.shape[0]): - #Stop at the 233rd image as it's repeated - if i*10 + j == 223: + current_image = i * batch_size + j + + if current_image > len(images_list) - 1: break converted_image = grayscale_to_colour(segmentations[j]) - print 'Saving image %s/%s' %(i*10 + j, len(images_list)) + print 'Saving image %s/%s' %(current_image, len(images_list) - 1) plt.axis('off') plt.imshow(converted_image) - imsave(photo_dir + "/image_%s.png" %(i*10 + j), converted_image) - # plt.show() \ No newline at end of file + imsave(photo_dir + "/image_%s.png" %(current_image), converted_image) + # plt.show()