Hello, I found a performance issue in the definition of get_dataset in prepare_data.py: in image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(...), map is called without num_parallel_calls, so the elements are processed sequentially.
I think adding num_parallel_calls would make the input pipeline more efficient, because the map function could then process several elements in parallel.
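A minimal sketch of the suggested change, assuming TensorFlow 2.4+ (where tf.data.AUTOTUNE is available). load_and_preprocess is a hypothetical stand-in for whatever function get_dataset currently passes to map; it assumes PNG images resized to 224x224, which may not match the repo.

```python
import tensorflow as tf


def load_and_preprocess(path):
    # Hypothetical per-element function (assumption): reads one PNG file,
    # decodes it to 3 channels, resizes it and scales pixels to [0, 1].
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    image = tf.image.resize(image, [224, 224]) / 255.0
    return image


def get_dataset(all_image_path, map_fn=load_and_preprocess):
    # Same pipeline as before, but map now runs map_fn on several elements
    # concurrently; AUTOTUNE lets tf.data choose the parallelism at runtime.
    image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(
        map_fn,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return image_dataset
```

A fixed integer (for example the number of CPU cores) can be passed instead of tf.data.AUTOTUNE. By default the output order is still the same as the input order, so downstream code should not need to change.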
Here is the TensorFlow documentation that supports this.
Looking forward to your reply. By the way, I would be glad to open a PR to fix it if you are too busy.