Diff of /uNet_Subclassed_SCCE.py [000000] .. [3b7fea]


--- a
+++ b/uNet_Subclassed_SCCE.py
@@ -0,0 +1,845 @@
+# %% importing packages
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import mixed_precision
+from tensorflow.python.ops.numpy_ops import np_config
+np_config.enable_numpy_behavior()
+from skimage import measure
+import cv2 as cv
+import os
+import matplotlib.pyplot as plt
+plt.rcParams['figure.figsize'] = [5, 5]
+# you can alternatively launch this script with the terminal line below to
+# address a memory leak that occurs when using the dataset.shuffle buffer; the
+# workaround was found at the subsequent link.
+# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4.5.9 python3 uNet_Subclassed.py
+
+# https://stackoverflow.com/questions/55211315/memory-leak-with-tf-data/66971031#66971031
+
+
+# %% Citations
+#############################################################
+#############################################################
+# https://www.tensorflow.org/guide/keras/functional
+# https://www.tensorflow.org/tutorials/customization/custom_layers
+# https://keras.io/examples/keras_recipes/tfrecord/
+# https://arxiv.org/abs/1505.04597
+# https://www.tensorflow.org/guide/gpu
+
+# Defining Functions
+#############################################################
+#############################################################
+
+def parse_tf_elements(element):
+    '''Mapper function for parsing a single example out of the tfrecord
+       dataset.'''
+
+    # create placeholders for all the features in each example
+    data = {
+        'height' : tf.io.FixedLenFeature([],tf.int64),
+        'width' : tf.io.FixedLenFeature([],tf.int64),
+        'raw_image' : tf.io.FixedLenFeature([],tf.string),
+        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
+        'bbox_x' : tf.io.VarLenFeature(tf.float32),
+        'bbox_y' : tf.io.VarLenFeature(tf.float32),
+        'bbox_height' : tf.io.VarLenFeature(tf.float32),
+        'bbox_width' : tf.io.VarLenFeature(tf.float32),
+        'name' : tf.io.FixedLenFeature([],tf.string),
+    }
+
+    # pull out the current example
+    content = tf.io.parse_single_example(element, data)
+
+    # pull out each feature from the example 
+    height = content['height']
+    width = content['width']
+    raw_seg = content['raw_seg']
+    raw_image = content['raw_image']
+    name = content['name']
+
+    # note that the bounding boxes are parsed here but are not used. They were
+    # included in the dataset for possible future use (e.g. putting together
+    # something like YOLO for practice), but they haven't been thoroughly
+    # tested, so they could be buggy and should be vetted before use.
+    bbox_x = content['bbox_x']
+    bbox_y = content['bbox_y']
+    bbox_height = content['bbox_height']
+    bbox_width = content['bbox_width']
+
+    # convert the images to uint8, and reshape them accordingly
+    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
+    image = tf.reshape(image,shape=[height,width,3])
+    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)-1
+    # The class weights are included in the parser so that the loss function
+    # and accuracy metrics can weight each pixel by its class. Note that the
+    # last three weights are manually rescaled below (x3, x2, and /2,
+    # respectively) to counteract the segmentation imbalance observed for
+    # those classes. For reference:
+    # [2.72403952, 2.81034368, 4.36437716, 36.66264202, 108.40694198, 87.39903838]
+    weights = [2.15248481,
+               3.28798466, 
+               5.18559616, 
+               46.96594578*3, 
+               130.77512742*2, 
+               105.23678672/2]
+    weights = np.divide(weights,sum(weights))
+    
+    # the base weights are calculated by the tf_record_weight_determination.py
+    # file, and are derived from the percentage of each class in the dataset.
+    sample_weights = tf.gather(weights, indices=tf.cast(segmentation, tf.int32))
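+    # To illustrate the gather (made-up values, not the real weights): with
+    # weights [0.1, 0.9] and a segmentation patch [[0, 1], [1, 0]], tf.gather
+    # returns [[0.1, 0.9], [0.9, 0.1]], i.e. every pixel picks up the weight
+    # of its class.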
+
+    return(image,segmentation,sample_weights)
+
+#############################################################
+
+class EncoderBlock(layers.Layer):
+    '''This class implements an encoder block with two convolutional layers
+       and an option for returning both a max-pooled output (with a stride and
+       pool size of (2,2)) and the output of the second convolution, used for
+       the skip connections implemented later in the network during the
+       decoding section. All padding is set to "same" for cleanliness.
+       
+       When initializing, it receives the number of filters to be used in both
+       of the convolutional layers as well as the kernel size and stride for
+       those same layers. It also receives the trainable flag for freezing the
+       convolutional layers and for use with the batch normalization layers.'''
+
+    def __init__(self,
+                 filters,
+                 kernel_size=(3,3),
+                 strides=(1,1),
+                 trainable=True,
+                 name='encoder_block',
+                 **kwargs):
+
+        super(EncoderBlock,self).__init__(trainable=trainable, name=name,
+                                          **kwargs)
+        # When initializing, this object receives a trainable parameter for
+        # freezing the convolutional layers.
+
+        # including the image normalization within the network for easier image
+        # processing during inference
+        self.image_normalization = layers.Rescaling(scale=1./255)
+
+        # below creates the first of two convolutional layers
+        self.conv1 = layers.Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding='same',
+                      name='encoder_conv1',
+                      trainable=trainable)
+
+        # second of two convolutional layers
+        self.conv2 = layers.Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding='same',
+                      name='encoder_conv2',
+                      trainable=trainable)
+
+        # creates the max-pooling layer for downsampling the image.
+        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
+                                    strides=(2,2),
+                                    padding='same',
+                                    name='enc_pool')
+
+        # ReLU layer for activations.
+        self.ReLU = layers.ReLU()
+        
+        # both batch normalization layers for use with their corresponding
+        # convolutional layers.
+        self.batch_norm1 = tf.keras.layers.BatchNormalization()
+        self.batch_norm2 = tf.keras.layers.BatchNormalization()
+
+    def call(self,input,normalization=False,training=True,include_pool=True):
+        
+        # first conv of the encoder block
+        if normalization:
+            x = self.image_normalization(input)
+            x = self.conv1(x)
+        else:
+            x = self.conv1(input)
+
+        x = self.batch_norm1(x,training=training)
+        x = self.ReLU(x)
+
+        # second conv of the encoder block
+        x = self.conv2(x)
+        x = self.batch_norm2(x,training=training)
+        x = self.ReLU(x)
+        
+        # calculate and include the max pooling layer if include_pool is true.
+        # This output is used for the skip connections later in the network.
+        if include_pool:
+            pooled_x = self.enc_pool(x)
+            return(x,pooled_x)
+
+        else:
+            return(x)
+
+
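+# A minimal shape sketch for the block above (illustrative values, not taken
+# from the training setup below):
+#   enc = EncoderBlock(filters=4)
+#   full_res, pooled = enc(tf.zeros((1, 64, 64, 3)), normalization=True)
+# full_res has shape (1, 64, 64, 4) and pooled has shape (1, 32, 32, 4),
+# since only the max-pooling halves the spatial dimensions.
+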
+#############################################################
+
+class DecoderBlock(layers.Layer):
+    '''This class implements a decoder block that, when called, receives both
+       an input and a "skip connection". The input is passed to the
+       "up convolution" or transpose conv layer to double the dimensions before
+       being concatenated with its associated skip connection from the encoder
+       section of the network. All padding is set to "same" for cleanliness.
+       The decoder block also has an option for including an additional
+       "segmentation" layer, which is a (1,1) convolution with 6 filters that
+       produces the logits for the one-hot encoded ground truth.
+       
+       When initializing, it receives the number of filters to be used in the
+       up convolutional layer as well as in the other two forward convolutions.
+       The received kernel_size and strides are used for the forward
+       convolutions, with the up convolution kernel and stride set to (2,2).'''
+    def __init__(self,
+                 filters,
+                 trainable=True,
+                 kernel_size=(3,3),
+                 strides=(1,1),
+                 name='DecoderBlock',
+                 **kwargs):
+
+        super(DecoderBlock,self).__init__(trainable=trainable, name=name,
+                                          **kwargs)
+
+        # creating the up convolution layer
+        self.up_conv = layers.Conv2DTranspose(filters=filters,
+                                              kernel_size=(2,2),
+                                              strides=(2,2),
+                                              padding='same',
+                                              name='decoder_upconv',
+                                              trainable=trainable)
+
+        # the first of two forward convolutional layers
+        self.conv1 = layers.Conv2D(filters=filters,
+                                   kernel_size=kernel_size,
+                                   strides=strides,
+                                   padding='same',
+                                   name ='decoder_conv1',
+                                   trainable=trainable)
+
+        # second convolutional layer
+        self.conv2 = layers.Conv2D(filters=filters,
+                                   kernel_size=kernel_size,
+                                   strides=strides,
+                                   padding='same',
+                                   name ='decoder_conv2',
+                                   trainable=trainable)
+
+        # this creates the output prediction logits layer.
+        self.seg_out = layers.Conv2D(filters=6,
+                        kernel_size=(1,1),
+                        name='conv_feature_map')
+
+        # ReLU for activation of all above layers
+        self.ReLU = layers.ReLU()
+        
+        # the individual batch normalization layers for their respective 
+        # convolutional layers.
+        self.batch_norm1 = tf.keras.layers.BatchNormalization()
+        self.batch_norm2 = tf.keras.layers.BatchNormalization()
+
+
+    def call(self,input,skip_conn,training=True,segmentation=False,prob_dist=True):
+        
+        up = self.up_conv(input) # perform image up convolution
+        # concatenate the input and the skip_conn along the features axis
+        concatenated = layers.concatenate([up,skip_conn],axis=-1)
+
+        # first convolution 
+        x = self.conv1(concatenated)
+        x = self.batch_norm1(x,training=training)
+        x = self.ReLU(x)
+
+        # second convolution
+        x = self.conv2(x)
+        x = self.batch_norm2(x,training=training)
+        x = self.ReLU(x)
+
+        # if segmentation is True, then run the segmentation (1,1) convolution
+        # and use the Softmax to produce a probability distribution.
+        if segmentation:
+            seg = self.seg_out(x)
+            # deliberately set as "float32" to ensure proper calculation if 
+            # switching to mixed precision for efficiency
+            if prob_dist:
+                seg = layers.Softmax(dtype='float32')(seg)
+
+            return(seg)
+
+        else:
+            return(x)
+
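+# Shape sketch for the block above (illustrative values only): a
+# DecoderBlock(filters=4) called with an input of shape (1, 32, 32, 8) and a
+# skip connection of shape (1, 64, 64, 4) up-convolves the input to
+# (1, 64, 64, 4), concatenates to (1, 64, 64, 8), and returns (1, 64, 64, 4)
+# after the two forward convolutions, or (1, 64, 64, 6) logits/probabilities
+# when segmentation=True.
+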
+#############################################################
+
+class uNet(keras.Model):
+    '''This is a sub-classed model that uses the encoder and decoder blocks
+       defined above to create a custom unet. The differences from the original 
+       paper include a variable filter scalar (filter_multiplier), batch 
+       normalization between each convolutional layer and the associated ReLU 
+       activation, as well as feature normalization implemented in the first 
+       layer of the network.'''
+    def __init__(self,filter_multiplier=2,**kwargs):
+        super(uNet,self).__init__()
+        
+        # Defining encoder blocks
+        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
+                                           name='Enc1')
+        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
+                                           name='Enc2')
+        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
+                                           name='Enc3')
+        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
+                                           name='Enc4')
+        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
+                                           name='Enc5')
+
+        # Defining decoder blocks. The names are in reverse order to make it 
+        # (hopefully) easier to understand which skip connections are associated
+        # with which decoder layers.
+        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
+                                           name='Dec4')
+        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
+                                           name='Dec3')
+        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
+                                           name='Dec2')
+        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
+                                           name='Dec1')
+
+
+    def call(self,inputs,training,predict=False,threshold=3):
+
+        # encoder    
+        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
+        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
+        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
+        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
+        enc5 = self.encoder_block5(input=enc4_pool,
+                                   include_pool=False,
+                                   training=training)
+
+        # enc4 = self.encoder_block4(input=enc3_pool,
+        #                            include_pool=False,
+        #                            training=training)
+
+
+        # decoder
+        dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
+        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
+        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
+        prob_dist_out = self.decoder_block1(input=dec2,
+                                            skip_conn=enc1,
+                                            segmentation=True,
+                                            training=training)
+        if predict:
+            seg_logits_out = self.decoder_block1(input=dec2,
+                                                 skip_conn=enc1,
+                                                 segmentation=True,
+                                                 training=training,
+                                                 prob_dist=False)
+
+        # This prediction branch is included to allow one to set a threshold
+        # on the network's confidence, taken here to be the maximum value of
+        # the logits predicted at each point in the image. Predictions for the
+        # vascular and neural tissues are only kept if they are above the
+        # confidence threshold; below the threshold, the predictions default
+        # to muscle, connective tissue, or background.
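+        # As a small example of the scatter update used below (made-up
+        # values): tf.tensor_scatter_nd_update(tf.zeros((2, 2)), [[0, 1]],
+        # [5.0]) returns [[0., 5.], [0., 0.]], placing each update value at
+        # its paired index.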
+        
+        if predict:
+            # rename the value for consistency and write protection.
+            y_pred = seg_logits_out
+            pred_shape = (1,1024,1024,6)
+            # Getting an image-sized preliminary segmentation prediction
+            squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))
+
+            # initializing the variable used for storing the maximum logits at 
+            # each pixel location.
+            max_value_predictions = tf.zeros((1024,1024))
+
+            # cycle through all the classes 
+            for idx in range(6):
+                
+                # current class logits
+                current_slice = tf.squeeze(y_pred[:,:,:,idx])
+                # find the locations where this class is predicted
+                current_indices = squeezed_prediction == idx
+                # define the shape so that this function can run in graph mode
+                # and not need eager execution.
+                current_indices.set_shape((1024,1024))
+                # Get the indices of where the idx class is predicted
+                indices = tf.where(squeezed_prediction == idx)
+                # get the output of boolean_mask to enable scatter update of the
+                # tensor. This is required because tensors do not support 
+                # mask indexing.
+                values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
+                # Place the maximum logit values at each point in an 
+                # image-size matrix, indicating the confidence in the prediction
+                # at each pixel. 
+                max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))
+            
+            for idx in [3,4]:
+                mask_list = []
+                for idx2 in range(6):
+                    if idx2 == idx:
+                        mid_mask = max_value_predictions<threshold
+                        mask_list.append(mid_mask.astype(tf.float32))
+                    else:
+                        mask_list.append(tf.zeros((1024,1024)))
+
+                mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)
+
+                indexes = tf.where(mask)
+                values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)
+
+                seg_logits_out = tf.tensor_scatter_nd_update(seg_logits_out,indexes,values_updates.astype(tf.float32))
+                prob_dist_out = layers.Softmax(dtype='float32')(seg_logits_out)
+            # print("updated logits!")
+
+
+            
+        return(prob_dist_out)
+
+
+    # def test_step(self, data):
+        
+    #     threshold = 3
+    #     x, y, weight = data
+    #     pred_shape = (1,1024,1024,6)
+
+    #     y_pred = self(x,training=False)
+
+    #     squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))
+
+    #     max_value_predictions = tf.zeros((1024,1024))
+
+    #     for idx in range(6):
+
+    #         current_slice = tf.squeeze(y_pred[:,:,:,idx])
+    #         current_indices = squeezed_prediction == idx
+    #         current_indices.set_shape((1024,1024))
+    #         indices = tf.where(squeezed_prediction == idx)
+    #         values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
+    #         max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))
+        
+    #     for idx in [3,4]:
+    #         mask_list = []
+    #         for idx2 in range(6):
+    #             if idx2 == idx:
+    #                 mid_mask = max_value_predictions<threshold
+    #                 mask_list.append(mid_mask.astype(tf.float32))
+    #             else:
+    #                 mask_list.append(tf.zeros((1024,1024)))
+
+    #         mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)
+
+    #         indexes = tf.where(mask)
+    #         values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)
+
+    #         y_pred = tf.tensor_scatter_nd_update(y_pred,indexes,values_updates.astype(tf.float32))
+
+    #     self.compiled_metrics.update_state(y, y_pred, sample_weight=weight)
+    #     self.compiled_loss(y, y_pred, sample_weight=weight)
+
+    #     return {m.name: m.result() for m in self.metrics}
+
+#############################################################
+
+class SanityCheck(keras.callbacks.Callback):
+
+    def __init__(self, testing_images):
+        super(SanityCheck, self).__init__()
+        self.testing_images = testing_images
+
+
+    def on_epoch_end(self,epoch, logs=None):
+        for image_pair in self.testing_images:
+            out = self.model.predict(image_pair[0],verbose=0)
+            image = cv.cvtColor(np.squeeze(np.asarray(image_pair[0]).copy()),cv.COLOR_BGR2RGB)
+            squeezed_gt = image_pair[1][0,:,:]
+            squeezed_prediction = tf.argmax(out,axis=-1)
+
+            fig,ax = plt.subplots(1,3)
+
+            ax[0].imshow(image)
+            ax[1].imshow(squeezed_gt,vmin=0, vmax=5)
+            ax[2].imshow(squeezed_prediction[0,:,:],vmin=0, vmax=5)
+
+            plt.show()
+            print(np.unique(squeezed_gt))
+            print(np.unique(squeezed_prediction[0,:,:]))
+
+
+#############################################################
+
+def load_dataset(file_names):
+    '''Receives a list of file names from a folder that contains tfrecord files
+       compiled previously. Takes these names and creates a tensorflow dataset
+       from them.'''
+
+    ignore_order = tf.data.Options()
+    ignore_order.experimental_deterministic = False
+    dataset = tf.data.TFRecordDataset(file_names)
+
+    # you can shard the dataset to reduce its size when necessary. Note that
+    # as written this keeps only one of every eight records (shard index 2).
+    
+    # order in the file names doesn't really matter, so ignoring it
+    dataset = dataset.with_options(ignore_order)
+
+    # mapping the dataset using the parse_tf_elements function defined earlier
+    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)
+    
+    return(dataset)
+
+#############################################################
+
+def get_dataset(file_names,batch_size):
+    '''Receives a list of file names of tfrecord shards from a dataset as well
+       as a batch size for the dataset.'''
+    
+    # uses the load_dataset function to retrieve the files and put them into a 
+    # dataset.
+    dataset = load_dataset(file_names)
+    
+    # creates a shuffle buffer of 300. The number was arbitrarily chosen; feel
+    # free to alter it to fit your hardware.
+    dataset = dataset.shuffle(300)
+
+    # adding the batch size to the dataset
+    dataset = dataset.batch(batch_size=batch_size)
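+
+    # (a possible addition, not used here) prefetching typically keeps the GPU
+    # fed while the CPU parses the next batch:
+    # dataset = dataset.prefetch(tf.data.AUTOTUNE)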
+
+    return(dataset)
+
+
+#############################################################
+#############################################################
+# %% Setting up the GPU, and setting memory growth to true so that it is
+# easier to see exactly how much memory the training process takes up. This
+# code is from a tensorflow tutorial.
+
+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+    try:
+        for gpu in gpus:
+            tf.config.experimental.set_memory_growth(gpu, True)
+        logical_gpus = tf.config.list_logical_devices('GPU')
+
+        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
+    except RuntimeError as e:
+        print(e)
+
+# use this to set mixed precision for higher efficiency later if you would like
+# mixed_precision.set_global_policy('mixed_float16')
+
+# %% setting up datasets and building model
+
+# directory where the dataset shards are stored
+home_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_6'
+training_directory = home_directory + '/train'
+val_directory = home_directory + '/validate'
+testing_directory = home_directory + '/test'
+
+os.chdir(home_directory)
+
+# only get the file names that follow the shard naming convention
+train_files = tf.io.gfile.glob(training_directory + \
+                              "/shard_*_of_*.tfrecords")
+val_files = tf.io.gfile.glob(val_directory + \
+                              "/shard_*_of_*.tfrecords")
+test_files = tf.io.gfile.glob(testing_directory + \
+                              "/shard_*_of_*.tfrecords")
+
+# create the datasets. The training dataset is set to repeat() because the
+# batch size and steps per epoch are altered from standard practice, both to
+# fit the model on the graphics card and to provide more meaningful and
+# frequent updates to the console.
+training_dataset = get_dataset(train_files,batch_size=3)
+training_dataset = training_dataset.repeat()
+validation_dataset = get_dataset(val_files,batch_size = 3)
+# testing has a batch size of 1 to facilitate visualization of predictions
+testing_dataset = get_dataset(test_files,batch_size=1)
+
+# explicitly puts the model on the GPU to show how large it is. 
+gpus = tf.config.list_logical_devices('GPU')
+with tf.device(gpus[0].name):
+    # the filter multiplier scales the filter depth of every block; a
+    # multiplier of 8 gives a largest filter depth of 256.
+    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
+    unet = uNet(filter_multiplier=12,) # 12 is the magic number
+    # build with an input image size of 1024*1024
+    out = unet(sample_data)
+    unet.summary()
+# %%
+
+unet.compile(
+    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
+    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
+    run_eagerly=False,
+    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()]
+)
+
+test_images = []
+for sample in testing_dataset.take(5):
+    #print(sample[0].shape)
+    test_images.append([sample[0],sample[1]])
+
+sanity_check = SanityCheck(test_images)
+
+
+def schedule(epoch, lr):
+    return(lr*0.97)
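+
+# the schedule above applies a multiplicative decay, lr_n = lr_0 * 0.97**n,
+# which works out to roughly a 26% reduction every 10 epochs.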
+
+lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
+
+
+reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
+                                                 mode='min',
+                                                 factor=0.8,
+                                                 patience=5,
+                                                 min_lr=0.000001,
+                                                 verbose=True,
+                                                 min_delta=0.01,)
+
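+# note that reduce_lr is defined above but is not passed to fit() below; only
+# the checkpoint, early stopping, and scheduler callbacks are active.
+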
+checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
+    'unet_seg_weights.{epoch:02d}-{val_sparse_categorical_accuracy:.4f}-{val_loss:.4f}.h5',
+    save_weights_only=True,
+    monitor='val_sparse_categorical_accuracy',
+    mode='max',
+    verbose=True
+    )
+
+early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,
+                                                     monitor='val_sparse_categorical_accuracy',
+                                                     mode='max',
+                                                     restore_best_weights=True,
+                                                     verbose=True,
+                                                     min_delta=0.001)
+
+# setting the number of batches to iterate through each epoch to a value much
+# lower than it normally would be, so that we can actually see what is going
+# on with the network, as well as have a meaningful early stopping.
+
+
+# %% fit the network!
+num_steps = 600
+
+history = unet.fit(training_dataset,
+                   epochs=100,
+                   steps_per_epoch=num_steps,
+                   validation_data=validation_dataset,
+                   verbose=2,
+                   callbacks=[checkpoint_cb,
+                              early_stopping_cb,
+                              lr_scheduler,])
+# %%
+
+
+
+# %%
+# evaluate the network after loading the weights
+unet.load_weights('unet_seg_weights.84-0.9163-0.0053.h5')
+results = unet.evaluate(testing_dataset)
+print(results)
+# %%
+# extracting loss vs epoch
+loss = history.history['loss']
+val_loss = history.history['val_loss']
+acc = history.history['sparse_categorical_accuracy']
+val_acc = history.history['val_sparse_categorical_accuracy']
+
+
+epochs = range(len(loss))
+
+figs, axes = plt.subplots(2,1)
+
+# plotting loss and validation loss
+axes[0].plot(epochs[1:],loss[1:])
+axes[0].plot(epochs[1:],val_loss[1:])
+axes[0].legend(['loss','val_loss'])
+axes[0].set(xlabel='epochs',ylabel='crossentropy loss')
+
+# plotting accuracy and validation accuracy
+axes[1].plot(epochs[1:],acc[1:])
+axes[1].plot(epochs[1:],val_acc[1:])
+axes[1].legend(['acc','val_acc'])
+axes[1].set(xlabel='epochs',ylabel='weighted accuracy')
+
+
+# %% exploring the predictions to better understand what the network is doing. 
+# This section is largely experimental, and should be treated as such. I have
+# included it in this network file for the sake of documentation and 
+# traceability, but it is not in the other network files for full image 
+# segmentation and directory segmentation because, well, those are functional 
+# and this is experimental.
+
+
+# the commented-out cells further down can be uncommented to use them
+images = []
+gt = []
+predictions = []
+# higher threshold means the network must be more confident.
+threshold = 3
+
+# taking out 15 of the next samples from the testing dataset and iterating 
+# through them
+for sample in testing_dataset.take(15):
+    # make sure it is producing the correct dimensions
+    print(sample[0].shape)
+    # take the image and convert it back to RGB, store in list
+    image = sample[0]
+    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
+    images.append(image)
+    # extract the ground truth and store in list
+    ground_truth = sample[1]
+    gt.append(ground_truth)
+    # perform inference
+    out = unet(sample[0],predict=True,threshold=threshold)
+    predictions.append(out)
+    # show the original input image
+    plt.imshow(image)
+    plt.show()
+    # store the ground truth, and flatten the prediction along the last
+    # (class) axis for display
+    squeezed_gt = ground_truth
+    squeezed_prediction = tf.argmax(out,axis=-1)
+    plt.imshow(squeezed_gt[0,:,:],vmin=0, vmax=5)
+    # print the number of classes in this tile
+    print(np.unique(squeezed_gt))
+    plt.show()
+    # show the flattened predictions
+    plt.imshow(squeezed_prediction[0,:,:],vmin=0, vmax=5)
+    print(np.unique(squeezed_prediction))
+    plt.show()
+
+# # %% 5, 6, 8
+# # select one of the images cycled through above to investigate further
+# image_to_investigate = 0
+# threshold = 2
+# # show the original image
+# plt.imshow(images[image_to_investigate])
+# plt.show()
+
+# # show the ground truth for this tile
+# squeezed_gt = gt[image_to_investigate]
+# plt.imshow(squeezed_gt[0,:,:])
+# # print the number of unique classes in the ground truth
+# print(np.unique(squeezed_gt))
+# plt.show()
+#  # flatten the prediction and show the probability distribution
+
+# out = predictions[image_to_investigate]
+
+
+# # plt.hist(out[:,:,:,4].reshape(-1),alpha=0.5,label='neural')
+# # plt.hist(out[:,:,:,3].reshape(-1),alpha=0.5,label='vascular')
+# # plt.legend(["neural",'vascular'])
+
+# out = predictions[image_to_investigate]
+# squeezed_prediction = np.squeeze(tf.argmax(out,axis=-1))
+
+# max_value_predictions = np.zeros(squeezed_prediction.shape)
+
+# for idx in range(6):
+#     current_slice = np.squeeze(out[:,:,:,idx])
+#     current_indices = squeezed_prediction == idx
+#     indices = tf.where(squeezed_prediction == idx)
+#     values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
+#     max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))
+
+# plt.imshow(max_value_predictions)
+# plt.show()
+
+# for idx in [3,4]:
+#     mask = np.zeros(out.shape)
+#     mask[:,:,:,idx] = max_value_predictions<threshold
+#     indices = tf.where(mask)
+#     values_updates = tf.boolean_mask(np.zeros(out.shape),mask).astype(tf.double)
+
+#     out = tf.tensor_scatter_nd_update(out,indices,values_updates.astype(tf.float32))
+
+# for idx in range(6):
+#     current_slice = np.squeeze(out[:,:,:,idx])
+#     current_indices = squeezed_prediction == idx
+#     indices = tf.where(squeezed_prediction == idx)
+#     values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
+#     max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))
+# plt.imshow(max_value_predictions)
+# plt.show()
+
+
+# squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
+# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
+# # plt.show()
+# # show the flattened image
+# plt.imshow(squeezed_prediction[0,:,:])
+# print(np.unique(squeezed_prediction))
+# plt.show()
+
+# squeezed_prediction = tf.argmax(out,axis=-1)
+# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
+# # plt.show()
+# # show the flattened image
+# plt.imshow(squeezed_prediction[0,:,:])
+# print(np.unique(squeezed_prediction))
+# plt.show()
+
+# # %%
+# image_to_investigate = 0
+# threshold = 1
+# y_pred = predictions[image_to_investigate]
+
+
+# pred_shape = (1,1024,1024,6)
+
+# squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))
+
+# max_value_predictions = tf.zeros((1024,1024))
+
+# for idx in range(6):
+
+#     current_slice = tf.squeeze(y_pred[:,:,:,idx])
+#     current_indices = squeezed_prediction == idx
+#     current_indices.set_shape((1024,1024))
+#     indices = tf.where(squeezed_prediction == idx)
+#     values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
+#     max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))
+
+# for idx in [3,4]:
+#     mask_list = []
+#     for idx2 in range(6):
+#         if idx2 == idx:
+#             mid_mask = max_value_predictions<threshold
+#             mask_list.append(mid_mask.astype(tf.float32))
+#         else:
+#             mask_list.append(tf.zeros((1024,1024)))
+
+#     mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)
+
+#     indexes = tf.where(mask)
+#     values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)
+
+#     y_pred = tf.tensor_scatter_nd_update(y_pred,indexes,values_updates.astype(tf.float32))
+
+# squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
+# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
+# # plt.show()
+# # show the flattened image
+# plt.imshow(squeezed_prediction[0,:,:])
+# print(np.unique(squeezed_prediction))
+# plt.show()
+
+# squeezed_prediction = tf.argmax(y_pred,axis=-1)
+# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
+# # plt.show()
+# # show the flattened image
+# plt.imshow(squeezed_prediction[0,:,:])
+# print(np.unique(squeezed_prediction))
+# plt.show()
+# # %%
+
+# %%