Diff of /uNet_Subclassed.py [000000] .. [3b7fea]

--- a
+++ b/uNet_Subclassed.py
@@ -0,0 +1,688 @@
+# %% importing packages
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import mixed_precision
+from tensorflow.python.ops.numpy_ops import np_config
+np_config.enable_numpy_behavior()
+from skimage import measure
+import cv2 as cv
+import os
+import matplotlib.pyplot as plt
+plt.rcParams['figure.figsize'] = [5, 5]
+# Alternatively, launch this script with the terminal line below to work around
+# the memory leak associated with the dataset.shuffle buffer. The workaround is
+# described at the link that follows.
+# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4.5.9 python3 uNet_Subclassed.py
+
+# https://stackoverflow.com/questions/55211315/memory-leak-with-tf-data/66971031#66971031
+
+
+# %% Citations
+#############################################################
+#############################################################
+# https://www.tensorflow.org/guide/keras/functional
+# https://www.tensorflow.org/tutorials/customization/custom_layers
+# https://keras.io/examples/keras_recipes/tfrecord/
+# https://arxiv.org/abs/1505.04597
+# https://www.tensorflow.org/guide/gpu
+
+# Defining Functions
+#############################################################
+#############################################################
+
+def parse_tf_elements(element):
+    '''Mapper function used to parse a single example from the tfrecord
+       shards.'''
+
+    # create placeholders for all the features in each example
+    data = {
+        'height' : tf.io.FixedLenFeature([],tf.int64),
+        'width' : tf.io.FixedLenFeature([],tf.int64),
+        'raw_image' : tf.io.FixedLenFeature([],tf.string),
+        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
+        'bbox_x' : tf.io.VarLenFeature(tf.float32),
+        'bbox_y' : tf.io.VarLenFeature(tf.float32),
+        'bbox_height' : tf.io.VarLenFeature(tf.float32),
+        'bbox_width' : tf.io.VarLenFeature(tf.float32)
+    }
+
+    # pull out the current example
+    content = tf.io.parse_single_example(element, data)
+
+    # pull out each feature from the example 
+    height = content['height']
+    width = content['width']
+    raw_seg = content['raw_seg']
+    raw_image = content['raw_image']
+    bbox_x = content['bbox_x']
+    bbox_y = content['bbox_y']
+    bbox_height = content['bbox_height']
+    bbox_width = content['bbox_width']
+
+    # parse the image and segmentation tensors as uint8 and reshape them to
+    # their original dimensions
+    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
+    image = tf.reshape(image,shape=[height,width,3])
+    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)
+    segmentation = tf.reshape(segmentation,shape=[height,width,1])
+    # one-hot encode the segmentation into 7 classes along the channel axis
+    one_hot_seg = tf.one_hot(tf.squeeze(segmentation),7,axis=-1)
+
+    # there is currently a bug with returning the bbox values, but fixing it
+    # isn't necessary for building the initial uNet for segmentation
+    # exploration.
+    # bbox = [bbox_x,bbox_y,bbox_height,bbox_width]
+    return(image,one_hot_seg)
+
+#############################################################
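+
+# For reference, a minimal, hypothetical sketch of how an example matching the
+# feature spec parsed above could have been serialized when the shards were
+# written. This is not the original writer script; the serialize_example name
+# and its arguments are assumptions for illustration only.
+def serialize_example(image, segmentation, bboxes):
+    '''image and segmentation are uint8 numpy arrays of shapes (H,W,3) and
+       (H,W,1); bboxes is a list of (x, y, height, width) float tuples.'''
+    feature = {
+        'height' : tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[0]])),
+        'width' : tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[1]])),
+        'raw_image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(image).numpy()])),
+        'raw_seg' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(segmentation).numpy()])),
+        'bbox_x' : tf.train.Feature(float_list=tf.train.FloatList(value=[b[0] for b in bboxes])),
+        'bbox_y' : tf.train.Feature(float_list=tf.train.FloatList(value=[b[1] for b in bboxes])),
+        'bbox_height' : tf.train.Feature(float_list=tf.train.FloatList(value=[b[2] for b in bboxes])),
+        'bbox_width' : tf.train.Feature(float_list=tf.train.FloatList(value=[b[3] for b in bboxes])),
+    }
+    example = tf.train.Example(features=tf.train.Features(feature=feature))
+    return example.SerializeToString()
+
+#############################################################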
+
+class EncoderBlock(layers.Layer):
+    '''This layer implements an encoder block with two convolutional layers. It
+       has an option for returning both the output of the second convolution
+       (used for the skip connections implemented later in the decoding section
+       of the network) and a max-pooled version of that output with a stride
+       and pool size of (2,2). All padding is set to "same" for cleanliness.
+
+       When initializing, it receives the number of filters to be used in both
+       of the convolutional layers as well as the kernel size and stride for
+       those same layers. It also receives the trainable flag for use with the
+       batch normalization layers.'''
+
+    def __init__(self,
+                 filters,
+                 kernel_size=(3,3),
+                 strides=(1,1),
+                 trainable=True,
+                 name='encoder_block',
+                 **kwargs):
+
+        super(EncoderBlock,self).__init__(trainable=trainable, name=name, **kwargs)
+        # the trainable parameter received here can be used to freeze the
+        # convolutional layers.
+
+        # including the image normalization within the network for easier image
+        # processing during inference
+        self.image_normalization = layers.Rescaling(scale=1./255)
+
+        # below creates the first of two convolutional layers
+        self.conv1 = layers.Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding='same',
+                      name='encoder_conv1',
+                      trainable=trainable)
+
+        # second of two convolutional layers
+        self.conv2 = layers.Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding='same',
+                      name='encoder_conv2',
+                      trainable=trainable)
+
+        # creates the max-pooling layer for downsampling the image.
+        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
+                                    strides=(2,2),
+                                    padding='same',
+                                    name='enc_pool')
+
+        # ReLU layer for activations.
+        self.ReLU = layers.ReLU()
+        
+        # both batch normalization layers for use with their corresponding
+        # convolutional layers.
+        self.batch_norm1 = tf.keras.layers.BatchNormalization()
+        self.batch_norm2 = tf.keras.layers.BatchNormalization()
+
+    def call(self,input,normalization=False,training=True,include_pool=True):
+        
+        # first conv of the encoder block
+        if normalization:
+            x = self.image_normalization(input)
+            x = self.conv1(x)
+        else:
+            x = self.conv1(input)
+
+        x = self.batch_norm1(x,training=training)
+        x = self.ReLU(x)
+
+        # second conv of the encoder block
+        x = self.conv2(x)
+        x = self.batch_norm2(x,training=training)
+        x = self.ReLU(x)
+        
+        # calculate and include the max pooling layer if include_pool is true.
+        # This output is used for the skip connections later in the network.
+        if include_pool:
+            pooled_x = self.enc_pool(x)
+            return(x,pooled_x)
+
+        else:
+            return(x)
+
+
+#############################################################
+
+class DecoderBlock(layers.Layer):
+    '''This layer implements a decoder block that, when called, receives both an
+       input and a "skip connection". The input is passed to the
+       "up convolution" or transpose conv layer to double the spatial dimensions
+       before being concatenated with its associated skip connection from the
+       encoder section of the network. All padding is set to "same" for
+       cleanliness. The decoder block also has an option for including an
+       additional "segmentation" layer, which is a (1,1) convolution with 7
+       filters that produces the logits for the one-hot encoded ground truth.
+
+       When initializing, it receives the number of filters to be used in the
+       up convolutional layer as well as the other two forward convolutions.
+       The received kernel_size and strides are used for the forward
+       convolutions, with the up convolution kernel and stride set to (2,2).'''
+    def __init__(self,
+                 filters,
+                 trainable=True,
+                 kernel_size=(3,3),
+                 strides=(1,1),
+                 name='DecoderBlock',
+                 **kwargs):
+
+        super(DecoderBlock,self).__init__(trainable=trainable, name=name, **kwargs)
+
+        # creating the up convolution layer
+        self.up_conv = layers.Conv2DTranspose(filters=filters,
+                                              kernel_size=(2,2),
+                                              strides=(2,2),
+                                              padding='same',
+                                              name='decoder_upconv',
+                                              trainable=trainable)
+
+        # the first of two forward convolutional layers
+        self.conv1 = layers.Conv2D(filters=filters,
+                                   kernel_size=kernel_size,
+                                   strides=strides,
+                                   padding='same',
+                                   name ='decoder_conv1',
+                                   trainable=trainable)
+
+        # second convolutional layer
+        self.conv2 = layers.Conv2D(filters=filters,
+                                   kernel_size=kernel_size,
+                                   strides=strides,
+                                   padding='same',
+                                   name ='decoder_conv2',
+                                   trainable=trainable)
+
+        # this creates the output prediction logits layer.
+        self.seg_out = layers.Conv2D(filters=7,
+                        kernel_size=(1,1),
+                        name='conv_feature_map')
+
+        # ReLU for activation of all above layers
+        self.ReLU = layers.ReLU()
+        
+        # the individual batch normalization layers for their respective 
+        # convolutional layers.
+        self.batch_norm1 = tf.keras.layers.BatchNormalization()
+        self.batch_norm2 = tf.keras.layers.BatchNormalization()
+
+
+    def call(self,input,skip_conn,training=True,segmentation=False):
+        
+        up = self.up_conv(input) # perform image up convolution
+        # concatenate the up-convolved input and the skip connection along the
+        # channel axis
+        concatenated = layers.concatenate([up,skip_conn],axis=-1)
+
+        # first convolution 
+        x = self.conv1(concatenated)
+        x = self.batch_norm1(x,training=training)
+        x = self.ReLU(x)
+
+        # second convolution
+        x = self.conv2(x)
+        x = self.batch_norm2(x,training=training)
+        x = self.ReLU(x)
+
+        # if segmentation is True, then run the segmentation (1,1) convolution
+        # and use the Softmax to produce a probability distribution.
+        if segmentation:
+            seg = self.seg_out(x)
+            # deliberately set as "float32" to ensure proper calculation if 
+            # switching to mixed precision for efficiency
+            prob = layers.Softmax(dtype='float32')(seg)
+            return(prob)
+
+        else:
+            return(x)
+
+#############################################################
+
+class uNet(keras.Model):
+    '''This is a sub-classed model that uses the encoder and decoder blocks
+       defined above to create a custom uNet. The differences from the original
+       paper include a variable filter scalar (filter_multiplier), batch
+       normalization between each convolutional layer and its associated ReLU
+       activation, and input rescaling (1/255) implemented in the first encoder
+       block of the network.'''
+    def __init__(self,filter_multiplier=2,**kwargs):
+        super(uNet,self).__init__()
+        
+        # Defining encoder blocks
+        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
+                                           name='Enc1')
+        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
+                                           name='Enc2')
+        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
+                                           name='Enc3')
+        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
+                                           name='Enc4')
+        # self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
+        #                                    name='Enc5')
+
+        # Defining decoder blocks. The names are in reverse order to make it 
+        # (hopefully) easier to understand which skip connections are associated
+        # with which decoder layers.
+        # self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
+        #                                    name='Dec4')
+        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
+                                           name='Dec3')
+        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
+                                           name='Dec2')
+        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
+                                           name='Dec1')
+
+
+    def call(self,inputs,training):
+
+        # encoder    
+        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
+        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
+        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
+
+        # enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
+        # enc5 = self.encoder_block5(input=enc4_pool,
+        #                            include_pool=False,
+        #                            training=training)
+
+        enc4 = self.encoder_block4(input=enc3_pool,
+                                   include_pool=False,
+                                   training=training)
+
+
+        # decoder
+        # dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
+        dec3 = self.decoder_block3(input=enc4,skip_conn=enc3,training=training)
+        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
+        seg_logits_out = self.decoder_block1(input=dec2,
+                                             skip_conn=enc1,
+                                             segmentation=True,
+                                             training=training)
+
+        return(seg_logits_out)
+
+#############################################################
+
+def load_dataset(file_names):
+    '''Receives a list of file names from a folder that contains tfrecord files
+       compiled previously. Takes these names and creates a tensorflow dataset
+       from them.'''
+
+    ignore_order = tf.data.Options()
+    ignore_order.experimental_deterministic = False
+    dataset = tf.data.TFRecordDataset(file_names)
+
+    # shard the dataset to reduce its size when necessary; as written, only one
+    # of eight shards (index 2) is used.
+    dataset = dataset.shard(num_shards=8,index=2)
+    
+    # order in the file names doesn't really matter, so ignoring it
+    dataset = dataset.with_options(ignore_order)
+
+    # mapping the dataset using the parse_tf_elements function defined earlier
+    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)
+    
+    return(dataset)
+
+#############################################################
+
+def get_dataset(file_names,batch_size):
+    '''Receives a list of file names of tfrecord shards from a dataset as well
+       as a batch size for the dataset.'''
+    
+    # uses the load_dataset function to retrieve the files and put them into a 
+    # dataset.
+    dataset = load_dataset(file_names)
+    
+    # creates a shuffle buffer of 300. Number was arbitrarily chosen, feel free
+    # to alter it to fit your hardware.
+    dataset = dataset.shuffle(300)
+
+    # adding the batch size to the dataset
+    dataset = dataset.batch(batch_size=batch_size)
+
+    return(dataset)
+
+#############################################################
+
+def weighted_cce_loss(y_true,y_pred):
+    '''Yes, this function essentially does what the "fit" argument
+       "class_weight" does when training a network. A separate custom loss
+       function is needed because, apparently, when the dataset is read from
+       tfrecord files a check is performed comparing the input, ground truth,
+       and weight values to each other. However, comparing the empty None that
+       is passed during the build call of the model with the weight
+       array/dictionary raises an error. Thus, here is a custom loss function
+       that applies a weighting to the different classes based on the
+       distribution of the classes within the entire dataset. Note that the
+       weights used here were computed from the training set only, excluding
+       images from the testing and validation sets, to prevent any over-eager
+       reviewers from screaming "information leak!!"
+       Just kidding, it is first to prevent an information leak, and second to
+       preempt over-eager reviewers.'''
+
+    # weights for each of the seven classes (index 0 is background); the
+    # commented lines below are previous weightings kept for reference.
+    # weights = [0, 2.95559004,   7.33779693,  12.87393959, 1000.43461107, 1200.63780628, 20.23600735]
+    # weights = [0, 0.80284233, 1.68275694, 2.63726432, 3000.8055788, 2000.26933614, 100.30741485] # last good run
+    weights = [0,2.72403952, 2.81034368, 4.36437716, 36.66264202, 108.40694198, 87.39903838]
+
+    count = 0
+
+    # placeholder that is overwritten on the first pass through the loop below
+    all_weights_for_loss = tf.expand_dims(tf.ones((1024,1024)).astype(tf.float64), axis=0)
+
+    for image in y_true:
+        weights_for_image = tf.ones((1024,1024)).astype(tf.float64)
+
+        # scatter each class weight into the pixels belonging to that class to
+        # build a per-pixel weight map for this image
+        for idx,weight in enumerate(weights):
+            mask = image[:,:,idx]
+            mask.set_shape((1024,1024))
+            indexes = tf.where(mask)
+            values_mask = mask*weight
+
+            values_updates = tf.boolean_mask(values_mask,mask).astype(tf.double)
+
+            weights_for_image = tf.tensor_scatter_nd_update(weights_for_image,indexes,values_updates)
+
+        if count == 0:
+            all_weights_for_loss = tf.expand_dims(weights_for_image, axis=0)
+        else:
+            all_weights_for_loss = tf.concat([all_weights_for_loss,tf.expand_dims(weights_for_image, axis=0)],axis=0)
+        count += 1
+
+    cce = tf.keras.losses.CategoricalCrossentropy()
+    cce_loss = cce(y_true,y_pred,all_weights_for_loss)
+
+    return(cce_loss)
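+
+# A minimal vectorized sketch of the same per-pixel weighting, assuming y_true
+# is a one-hot (batch, height, width, 7) tensor. This helper is provided for
+# illustration only and is not used in the training run below.
+def weighted_cce_loss_vectorized(y_true, y_pred):
+    class_weights = tf.constant([0, 2.72403952, 2.81034368, 4.36437716,
+                                 36.66264202, 108.40694198, 87.39903838],
+                                dtype=y_pred.dtype)
+    # collapsing the one-hot ground truth against the class weights yields a
+    # (batch, height, width) map of per-pixel sample weights
+    sample_weights = tf.reduce_sum(tf.cast(y_true, y_pred.dtype) * class_weights,
+                                   axis=-1)
+    cce = tf.keras.losses.CategoricalCrossentropy()
+    return cce(y_true, y_pred, sample_weight=sample_weights)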
+
+
+
+#############################################################
+#############################################################
+# %% Setting up the GPU, and setting memory growth to true so that it is easier
+# to see how much memory the training process is taking up exactly. This code is
+# from a tensorflow tutorial. 
+
+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+  try:
+    for gpu in gpus:
+      tf.config.experimental.set_memory_growth(gpu, True)
+    logical_gpus = tf.config.list_logical_devices('GPU')
+
+    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
+  except RuntimeError as e:
+    print(e)
+
+# use this to set mixed precision for higher efficiency later if you would like
+# mixed_precision.set_global_policy('mixed_float16')
+
+# %% setting up datasets and building model
+
+# directory where the dataset shards are stored
+os.chdir('/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/')
+training_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/train'
+val_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/validate'
+testing_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/test'
+
+# only get the file names that follow the shard naming convention
+train_files = tf.io.gfile.glob(training_directory + \
+                              "/shard_*_of_*.tfrecords")
+val_files = tf.io.gfile.glob(val_directory + \
+                              "/shard_*_of_*.tfrecords")
+test_files = tf.io.gfile.glob(testing_directory + \
+                              "/shard_*_of_*.tfrecords")
+
+# create the datasets. The training dataset is set to repeat() because the
+# batch size and steps per epoch are altered from standard practice to fit on
+# the graphics card and to provide more meaningful and frequent updates to the
+# console.
+training_dataset = get_dataset(train_files,batch_size=1)
+training_dataset = training_dataset.repeat()
+validation_dataset = get_dataset(val_files,batch_size = 1)
+# testing has a batch size of 1 to facilitate visualization of predictions
+testing_dataset = get_dataset(test_files,batch_size=1)
+
+# explicitly puts the model on the GPU to show how large it is. 
+gpus = tf.config.list_logical_devices('GPU')
+with tf.device(gpus[0].name):
+    # the filter_multiplier scales the filter depth of every block; with the
+    # multiplier of 32 used below, the deepest block (Enc4) has 16*32 = 512
+    # filters.
+    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
+    unet = uNet(filter_multiplier=32,)
+    # build with an input image size of 1024*1024
+    out = unet(sample_data)
+    unet.summary()
+# %%
+# running the network eagerly allows converting tensors to numpy arrays, which
+# the custom weighted_cce_loss above relies on if it is swapped in for the
+# standard loss below.
+unet.compile(
+    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
+    loss=tf.keras.losses.CategoricalCrossentropy(),
+    run_eagerly=True,
+    metrics=[tf.keras.metrics.Precision(name='precision'),
+                tf.keras.metrics.Recall(name='recall')]
+)
+
+# %%
+class SanityCheck(keras.callbacks.Callback):
+
+    def __init__(self, testing_images):
+        super(SanityCheck, self).__init__()
+        self.testing_images = testing_images
+
+
+    def on_epoch_end(self,epoch, logs=None):
+        for image_pair in self.testing_images:
+            out = self.model.predict(image_pair[0],verbose=0)
+            image = cv.cvtColor(np.squeeze(np.asarray(image_pair[0]).copy()),cv.COLOR_BGR2RGB)
+            squeezed_gt = tf.argmax(image_pair[1],axis=-1)
+            squeezed_prediction = tf.argmax(out,axis=-1)
+
+            vasc_gt = np.squeeze(image_pair[1][0,:,:,4])
+            neural_gt = np.squeeze(image_pair[1][0,:,:,5])
+            vasc_pred = np.squeeze(out[0,:,:,4])
+            neural_pred = np.squeeze(out[0,:,:,5])
+
+            fig,ax = plt.subplots(1,3)
+
+            ax[0].imshow(image)
+            ax[1].imshow(squeezed_gt[0,:,:],vmin=0, vmax=6)
+            ax[2].imshow(squeezed_prediction[0,:,:],vmin=0, vmax=6)
+            # ax[1].imshow(squeezed_gt[0,:,:]==4)
+            # ax[2].imshow(squeezed_prediction[0,:,:]==4)
+            plt.show()
+            print(np.unique(squeezed_gt[0,:,:]))
+            print(np.unique(squeezed_prediction[0,:,:]))
+
+
+test_images = []
+for sample in testing_dataset.take(5):
+    #print(sample[0].shape)
+    test_images.append([sample[0],sample[1]])
+    
+# %%
+
+# creating callbacks
+sanity_check = SanityCheck(test_images)
+
+def schedule(epoch, lr):
+    # decay the learning rate by a factor of 0.7 every third epoch (note that
+    # epoch 0 also satisfies epoch % 3 == 0, so the first decay happens
+    # immediately).
+    if (epoch % 3) == 0:
+        return(lr*0.7)
+    else:
+        return(lr)
+
+lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
+
+reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
+                                                 mode='min',
+                                                 factor=0.8,
+                                                 patience=5,
+                                                 min_lr=0.000001,
+                                                 verbose=True,
+                                                 min_delta=0.01,)
+
+checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('unet_seg_weights.{epoch:02d}-{val_loss:.2f}-{val_precision:.2f}-{val_recall:.2f}.h5',
+                                                   save_weights_only=True,
+                                                   monitor='loss',
+                                                   mode='min',
+                                                   verbose=True)
+
+early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,
+                                                     monitor='loss',
+                                                     mode='min',
+                                                     restore_best_weights=True,
+                                                     verbose=True,
+                                                     min_delta=0.01)
+
+# setting the number of batches to iterate through each epoch to a value much
+# lower than it normally would be so that we can actually see what is going
+# on with the network, as well as have a meaningful early stopping.
+
+
+# %% fit the network!
+# unet.load_weights('./unet_seg_weights.50-0.64-0.93-0.91.h5')
+num_steps = 100
+
+weights = {0:0,1:2.72403952,2:2.81034368,3:4.36437716,4:36.66264202, 5:108.40694198, 6:87.39903838}
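+
+# A minimal sketch of how per-class weights like those above could be derived
+# from pixel counts over a finite (non-repeated) training dataset using
+# inverse-frequency weighting. The exact values used here were computed
+# offline; this helper is an assumption for illustration only.
+def inverse_frequency_weights(dataset, num_classes=7):
+    counts = np.zeros(num_classes)
+    for _, one_hot_seg in dataset:
+        # accumulate the number of pixels belonging to each class
+        counts += tf.reduce_sum(one_hot_seg, axis=(0,1,2)).numpy()
+    total = counts.sum()
+    return {i: (total / (num_classes * c) if c > 0 else 0.0)
+            for i, c in enumerate(counts)}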
+
+history = unet.fit(training_dataset,
+                   epochs=100,
+                   steps_per_epoch=num_steps,
+                   validation_data=validation_dataset,
+                   class_weight=weights,
+                   callbacks=[checkpoint_cb,
+                              early_stopping_cb,
+                              reduce_lr,
+                              sanity_check,])
+# %%
+
+
+
+# %%
+# evaluate the network after loading the weights
+# unet.load_weights('./unet_seg_weights.49-0.52-0.94-0.92.h5')
+results = unet.evaluate(testing_dataset)
+print(results)
+# %%
+# extracting loss vs epoch
+loss = history.history['loss']
+val_loss = history.history['val_loss']
+# extracting precision vs epoch
+precision = history.history['precision']
+val_precision = history.history['val_precision']
+# extracting recall vs epoch
+recall = history.history['recall']
+val_recall = history.history['val_recall']
+
+epochs = range(len(loss))
+
+figs, axes = plt.subplots(3,1)
+
+# plotting loss and validation loss
+axes[0].plot(epochs,loss)
+axes[0].plot(epochs,val_loss)
+axes[0].legend(['loss','val_loss'])
+axes[0].set(xlabel='epochs',ylabel='crossentropy loss')
+
+# plotting precision and validation precision
+axes[1].plot(epochs,precision)
+axes[1].plot(epochs,val_precision)
+axes[1].legend(['precision','val_precision'])
+axes[1].set(xlabel='epochs',ylabel='precision')
+
+# plotting recall validation recall
+axes[2].plot(epochs,recall)
+axes[2].plot(epochs,val_recall)
+axes[2].legend(['recall','val_recall'])
+axes[2].set(xlabel='epochs',ylabel='recall')
+
+
+
+# %% exploring the predictions to better understand what the network is doing
+
+images = []
+gt = []
+predictions = []
+
+# taking out 10 of the next samples from the testing dataset and iterating 
+# through them
+for sample in testing_dataset.take(10):
+    # make sure it is producing the correct dimensions
+    print(sample[0].shape)
+    # take the image and convert it back to RGB, store in list
+    image = sample[0]
+    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
+    images.append(image)
+    # extract the ground truth and store in list
+    ground_truth = sample[1]
+    gt.append(ground_truth)
+    # perform inference
+    out = unet.predict(sample[0])
+    predictions.append(out)
+    # show the original input image
+    plt.imshow(image)
+    plt.show()
+    # flatten the one-hot encoded ground truth along the last axis and show the
+    # resulting image
+    squeezed_gt = tf.argmax(ground_truth,axis=-1)
+    squeezed_prediction = tf.argmax(out,axis=-1)
+    plt.imshow(squeezed_gt[0,:,:],vmin=0, vmax=6)
+    # print the class indices present in this tile
+    print(np.unique(squeezed_gt))
+    plt.show()
+    # show the flattened predictions
+    plt.imshow(squeezed_prediction[0,:,:],vmin=0, vmax=6)
+    print(np.unique(squeezed_prediction))
+    plt.show()
+
+# %%
+# select one of the images cycled through above to investigate further
+image_to_investigate = 6
+
+# show the original image
+plt.imshow(images[image_to_investigate])
+plt.show()
+
+# show the ground truth for this tile
+squeezed_gt = tf.argmax(gt[image_to_investigate],axis=-1)
+plt.imshow(squeezed_gt[0,:,:])
+# print the unique class indices in the ground truth
+print(np.unique(squeezed_gt))
+plt.show()
+# flatten the prediction along the class axis, and show the predicted
+# probability map for class 3
+squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
+plt.imshow(predictions[image_to_investigate][0,:,:,3])
+plt.show()
+# show the flattened image
+plt.imshow(squeezed_prediction[0,:,:])
+print(np.unique(squeezed_prediction))
+plt.show()
+
+# %%