Diff of /uNet_Subclassed.py [000000] .. [3b7fea]

# %% importing packages

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
from skimage import measure
import cv2 as cv
import os
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [5, 5]
# you can alternatively call this script with the following line in the
# terminal to address the memory leak that occurs when using the
# dataset.shuffle buffer. The fix was found at the link below.
# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4.5.9 python3 uNet_Subclassed.py

# https://stackoverflow.com/questions/55211315/memory-leak-with-tf-data/66971031#66971031

# %% Citations
#############################################################
#############################################################
# https://www.tensorflow.org/guide/keras/functional
# https://www.tensorflow.org/tutorials/customization/custom_layers
# https://keras.io/examples/keras_recipes/tfrecord/
# https://arxiv.org/abs/1505.04597
# https://www.tensorflow.org/guide/gpu

# Defining Functions
#############################################################
#############################################################

def parse_tf_elements(element):
    '''This function is the mapper function for retrieving examples from the
       tfrecord.'''

    # create placeholders for all the features in each example
    data = {
        'height' : tf.io.FixedLenFeature([],tf.int64),
        'width' : tf.io.FixedLenFeature([],tf.int64),
        'raw_image' : tf.io.FixedLenFeature([],tf.string),
        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
        'bbox_x' : tf.io.VarLenFeature(tf.float32),
        'bbox_y' : tf.io.VarLenFeature(tf.float32),
        'bbox_height' : tf.io.VarLenFeature(tf.float32),
        'bbox_width' : tf.io.VarLenFeature(tf.float32)
    }

    # pull out the current example
    content = tf.io.parse_single_example(element, data)

    # pull out each feature from the example
    height = content['height']
    width = content['width']
    raw_seg = content['raw_seg']
    raw_image = content['raw_image']
    bbox_x = content['bbox_x']
    bbox_y = content['bbox_y']
    bbox_height = content['bbox_height']
    bbox_width = content['bbox_width']

    # convert the images to uint8, reshape them, and one-hot encode the
    # segmentation into its 7 classes
    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
    image = tf.reshape(image,shape=[height,width,3])
    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)
    segmentation = tf.reshape(segmentation,shape=[height,width,1])
    one_hot_seg = tf.one_hot(tf.squeeze(segmentation),7,axis=-1)

    # there is currently a bug with returning the bbox, but it isn't necessary
    # to fix for creating the initial uNet for segmentation exploration

    # bbox = [bbox_x,bbox_y,bbox_height,bbox_width]
    return(image,one_hot_seg)

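# For reference, a minimal sketch of how one example matching the feature spec
# above could have been serialized when the shards were written (an assumption
# inferred from parse_tf_elements; the actual shard-writing script lives
# elsewhere, and the bboxes dict here is hypothetical):
# def serialize_example(image, seg, bboxes):
#     feature = {
#         'height' : tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[0]])),
#         'width' : tf.train.Feature(int64_list=tf.train.Int64List(value=[image.shape[1]])),
#         'raw_image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(image).numpy()])),
#         'raw_seg' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(seg).numpy()])),
#         'bbox_x' : tf.train.Feature(float_list=tf.train.FloatList(value=bboxes['x'])),
#         'bbox_y' : tf.train.Feature(float_list=tf.train.FloatList(value=bboxes['y'])),
#         'bbox_height' : tf.train.Feature(float_list=tf.train.FloatList(value=bboxes['height'])),
#         'bbox_width' : tf.train.Feature(float_list=tf.train.FloatList(value=bboxes['width'])),
#     }
#     return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
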
#############################################################

class EncoderBlock(layers.Layer):
    '''This layer returns an encoder block with two convolutional layers and
       an option for returning both a max-pooled output with a stride and pool
       size of (2,2) and the output of the second convolution for skip
       connections implemented later in the network during the decoding
       section. All padding is set to "same" for cleanliness.

       When initializing it receives the number of filters to be used in both
       of the convolutional layers as well as the kernel size and stride for
       those same layers. It also receives the trainable variable for use with
       the batch normalization layers.'''

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 strides=(1,1),
                 trainable=True,
                 name='encoder_block',
                 **kwargs):

        super(EncoderBlock,self).__init__(trainable=trainable, name=name, **kwargs)
        # when initializing, this object receives a trainable parameter for
        # freezing the convolutional layers.

        # including the image normalization within the network for easier image
        # processing during inference
        self.image_normalization = layers.Rescaling(scale=1./255)

        # below creates the first of two convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv1',
                      trainable=trainable)

        # second of two convolutional layers
        self.conv2 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv2',
                      trainable=trainable)

        # creates the max-pooling layer for downsampling the image.
        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
                                    strides=(2,2),
                                    padding='same',
                                    name='enc_pool')

        # ReLU layer for activations.
        self.ReLU = layers.ReLU()

        # both batch normalization layers for use with their corresponding
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self,input,normalization=False,training=True,include_pool=True):

        # first conv of the encoder block, with optional input normalization
        if normalization:
            x = self.image_normalization(input)
            x = self.conv1(x)
        else:
            x = self.conv1(input)

        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second conv of the encoder block
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # calculate and include the max pooling layer if include_pool is True.
        # The un-pooled output is what gets used for the skip connections later
        # in the network.
        if include_pool:
            pooled_x = self.enc_pool(x)
            return(x,pooled_x)

        else:
            return(x)


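# A quick sketch of the encoder block's behavior (illustrative only, uncomment
# to run): for a (1, 64, 64, 3) input the skip output keeps the spatial size
# while the pooled output halves it.
# demo_enc = EncoderBlock(filters=16, name='demo_enc')
# demo_skip, demo_pooled = demo_enc(input=tf.zeros((1, 64, 64, 3)),
#                                   normalization=True,
#                                   training=False)
# print(demo_skip.shape, demo_pooled.shape)  # (1, 64, 64, 16) (1, 32, 32, 16)
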
#############################################################

class DecoderBlock(layers.Layer):
    '''This layer returns a decoder block that when called receives both an
       input and a "skip connection". The input is passed to the
       "up convolution" or transpose conv layer to double the dimensions before
       being concatenated with its associated skip connection from the encoder
       section of the network. All padding is set to "same" for cleanliness.
       The decoder block also has an option for including an additional
       "segmentation" layer, which is a (1,1) convolution with 7 filters, which
       produces the logits for the one-hot encoded ground truth.

       When initializing it receives the number of filters to be used in the
       up convolutional layer as well as the other two forward convolutions.
       The received kernel_size and stride is used for the forward convolutions,
       with the up convolution kernel and stride set to be (2,2).'''
    def __init__(self,
                 filters,
                 trainable=True,
                 kernel_size=(3,3),
                 strides=(1,1),
                 name='DecoderBlock',
                 **kwargs):

        super(DecoderBlock,self).__init__(trainable=trainable, name=name, **kwargs)

        # creating the up convolution layer
        self.up_conv = layers.Conv2DTranspose(filters=filters,
                                              kernel_size=(2,2),
                                              strides=(2,2),
                                              padding='same',
                                              name='decoder_upconv',
                                              trainable=trainable)

        # the first of two forward convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name ='decoder_conv1',
                                   trainable=trainable)

        # second convolutional layer
        self.conv2 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name ='decoder_conv2',
                                   trainable=trainable)

        # this creates the output prediction logits layer.
        self.seg_out = layers.Conv2D(filters=7,
                        kernel_size=(1,1),
                        name='conv_feature_map')

        # ReLU for activation of all above layers
        self.ReLU = layers.ReLU()

        # the individual batch normalization layers for their respective
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()


    def call(self,input,skip_conn,training=True,segmentation=False):

        up = self.up_conv(input) # perform image up convolution
        # concatenate the input and the skip_conn along the features axis
        concatenated = layers.concatenate([up,skip_conn],axis=-1)

        # first convolution
        x = self.conv1(concatenated)
        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second convolution
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # if segmentation is True, then run the segmentation (1,1) convolution
        # and use the Softmax to produce a probability distribution.
        if segmentation:
            seg = self.seg_out(x)
            # deliberately set as "float32" to ensure proper calculation if
            # switching to mixed precision for efficiency
            prob = layers.Softmax(dtype='float32')(seg)
            return(prob)

        else:
            return(x)

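# A matching sketch for the decoder block (illustrative only, uncomment to
# run): a (1, 32, 32, 32) input is up-convolved to (1, 64, 64, 16),
# concatenated with a (1, 64, 64, 16) skip connection, and passed through the
# two forward convolutions; with segmentation=True the output would instead be
# a 7-channel probability map.
# demo_dec = DecoderBlock(filters=16, name='demo_dec')
# demo_out = demo_dec(input=tf.zeros((1, 32, 32, 32)),
#                     skip_conn=tf.zeros((1, 64, 64, 16)),
#                     training=False)
# print(demo_out.shape)  # (1, 64, 64, 16)
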
#############################################################

class uNet(keras.Model):
    '''This is a sub-classed model that uses the encoder and decoder blocks
       defined above to create a custom unet. The differences from the original
       paper include a variable filter scalar (filter_multiplier), batch
       normalization between each convolutional layer and the associated ReLU
       activation, as well as feature normalization implemented in the first
       layer of the network.'''
    def __init__(self,filter_multiplier=2,**kwargs):
        super(uNet,self).__init__()

        # Defining encoder blocks
        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
                                           name='Enc1')
        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
                                           name='Enc2')
        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
                                           name='Enc3')
        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
                                           name='Enc4')
        # self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
        #                                    name='Enc5')

        # Defining decoder blocks. The names are in reverse order to make it
        # (hopefully) easier to understand which skip connections are associated
        # with which decoder layers.
        # self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
        #                                    name='Dec4')
        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
                                           name='Dec3')
        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
                                           name='Dec2')
        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
                                           name='Dec1')


    def call(self,inputs,training):

        # encoder
        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)

        # enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
        # enc5 = self.encoder_block5(input=enc4_pool,
        #                            include_pool=False,
        #                            training=training)

        # the fourth encoder block acts as the bottleneck, so no pooling here
        enc4 = self.encoder_block4(input=enc3_pool,
                                   include_pool=False,
                                   training=training)


        # decoder
        # dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
        dec3 = self.decoder_block3(input=enc4,skip_conn=enc3,training=training)
        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
        seg_logits_out = self.decoder_block1(input=dec2,
                                             skip_conn=enc1,
                                             segmentation=True,
                                             training=training)

        return(seg_logits_out)

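# With the three active pooling stages, the input height and width should be
# divisible by 8 so the transposed convolutions line up with their skip
# connections. A quick shape check (illustrative only, uncomment to run):
# demo_unet = uNet(filter_multiplier=2)
# demo_out = demo_unet(tf.zeros((1, 256, 256, 3)), training=False)
# print(demo_out.shape)  # expect (1, 256, 256, 7)
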
#############################################################

def load_dataset(file_names):
    '''Receives a list of file names from a folder that contains tfrecord files
       compiled previously. Takes these names and creates a tensorflow dataset
       from them.'''

    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(file_names)

    # you can shard the dataset to reduce its size when necessary. Note that
    # this keeps only one out of every num_shards records.
    dataset = dataset.shard(num_shards=8,index=2)

    # order in the file names doesn't really matter, so ignoring it
    dataset = dataset.with_options(ignore_order)

    # mapping the dataset using the parse_tf_elements function defined earlier
    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)

    return(dataset)

#############################################################

def get_dataset(file_names,batch_size):
    '''Receives a list of file names of tfrecord shards from a dataset as well
       as a batch size for the dataset.'''

    # uses the load_dataset function to retrieve the files and put them into a
    # dataset.
    dataset = load_dataset(file_names)

    # creates a shuffle buffer of 300. Number was arbitrarily chosen, feel free
    # to alter as fits your hardware.
    dataset = dataset.shuffle(300)

    # adding the batch size to the dataset
    dataset = dataset.batch(batch_size=batch_size)

    return(dataset)

#############################################################

def weighted_cce_loss(y_true,y_pred):
    '''Yes, this function essentially does what the "fit" argument
       "class_weight" does when training a network. I had to create this
       separate custom loss function because apparently, when using tfrecord
       files to read your dataset, a check is performed comparing the input,
       ground truth, and weights values to each other. However, a comparison
       between the empty None that is passed during the build call of the model
       and the weight array/dictionary returns an error. Thus, here is a custom
       loss function that applies a weighting to the different classes based on
       the distribution of the classes within the entire dataset. Note that the
       weights used here are only from the training set, not including images
       from the testing and validation sets, to prevent any over-eager reviewers
       from screaming "information leak!!"
       Just kidding, it is first to prevent an information leak, and second to
       preempt over-eager reviewers.'''

    # weights for each of the 7 classes, starting with the background class.
    # The commented-out weight sets below are kept from earlier runs for
    # reference.
    # weights = [0, 2.95559004,   7.33779693,  12.87393959, 1000.43461107, 1200.63780628, 20.23600735]
    # weights = [0, 0.80284233, 1.68275694, 2.63726432, 3000.8055788, 2000.26933614, 100.30741485] # last good run
    # [0,2.72403952, 2.81034368, 4.36437716, 36.66264202, 108.40694198, 87.39903838]
    weights = [0,2.72403952, 2.81034368, 4.36437716, 36.66264202, 108.40694198, 87.39903838]

    count = 0

    all_weights_for_loss = tf.expand_dims(tf.ones((1024,1024)).astype(tf.float64), axis=0)

    # build a (batch, 1024, 1024) tensor of per-pixel weights, one image at a
    # time, by scattering each class weight into the pixels of that class
    for image in y_true:
        weights_for_image = tf.ones((1024,1024)).astype(tf.float64)

        for idx,weight in enumerate(weights):
            mask = image[:,:,idx]
            mask.set_shape((1024,1024))
            indexes = tf.where(mask)
            values_mask = mask*weights[idx]

            values_updates = tf.boolean_mask(values_mask,mask).astype(tf.double)

            weights_for_image = tf.tensor_scatter_nd_update(weights_for_image,indexes,values_updates)

        if count == 0:
            all_weights_for_loss = tf.expand_dims(weights_for_image, axis=0)
        else:
            all_weights_for_loss = tf.concat([all_weights_for_loss,tf.expand_dims(weights_for_image, axis=0)],axis=0)
        count += 1

    cce = tf.keras.losses.CategoricalCrossentropy()
    cce_loss = cce(y_true,y_pred,all_weights_for_loss)

    return(cce_loss)


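# An equivalent vectorized sketch of the per-pixel weighting above (not used by
# this script; it should match the loop version for strictly one-hot ground
# truth, and avoids the per-image Python loops):
# def weighted_cce_loss_vectorized(y_true, y_pred):
#     class_weights = tf.constant([0, 2.72403952, 2.81034368, 4.36437716,
#                                  36.66264202, 108.40694198, 87.39903838],
#                                 dtype=tf.float32)
#     # summing one_hot * weights over the channel axis gives the same
#     # (batch, height, width) weight map as the nested loops above
#     pixel_weights = tf.reduce_sum(tf.cast(y_true, tf.float32) * class_weights,
#                                   axis=-1)
#     cce = tf.keras.losses.CategoricalCrossentropy()
#     return cce(y_true, y_pred, sample_weight=pixel_weights)
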
#############################################################
#############################################################
# %% Setting up the GPU, and setting memory growth to True so that it is easier
# to see exactly how much memory the training process is taking up. This code
# is from a tensorflow tutorial.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')

    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    print(e)

# use this to set mixed precision for higher efficiency later if you would like
# mixed_precision.set_global_policy('mixed_float16')

# %% setting up datasets and building model

# directory where the dataset shards are stored
os.chdir('/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/')
training_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/train'
val_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/validate'
testing_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_5/test'

# only get the file names that follow the shard naming convention
train_files = tf.io.gfile.glob(training_directory + \
                              "/shard_*_of_*.tfrecords")
val_files = tf.io.gfile.glob(val_directory + \
                              "/shard_*_of_*.tfrecords")
test_files = tf.io.gfile.glob(testing_directory + \
                              "/shard_*_of_*.tfrecords")

# create the datasets. The training dataset is set to repeat() because the
# number of batches per epoch is reduced from standard practice below, both to
# fit on the graphics card and to provide more meaningful and frequent updates
# to the console.
training_dataset = get_dataset(train_files,batch_size=1)
training_dataset = training_dataset.repeat()
validation_dataset = get_dataset(val_files,batch_size = 1)
# testing has a batch size of 1 to facilitate visualization of predictions
testing_dataset = get_dataset(test_files,batch_size=1)

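# Quick sanity peek at one parsed, batched element (illustrative only,
# uncomment to run):
# for demo_img, demo_seg in training_dataset.take(1):
#     print(demo_img.shape, demo_img.dtype)   # e.g. (1, 1024, 1024, 3) uint8
#     print(demo_seg.shape, demo_seg.dtype)   # e.g. (1, 1024, 1024, 7) float32
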
# explicitly puts the model on the GPU to show how large it is.
gpus = tf.config.list_logical_devices('GPU')
with tf.device(gpus[0].name):
    # the filter multiplier of 32 gives a largest filter depth of 16*32 = 512
    # in the deepest encoder block.
    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
    unet = uNet(filter_multiplier=32,)
    # build with an input image size of 1024x1024
    out = unet(sample_data)
    unet.summary()
# %%
# running the network eagerly because it allows us to convert tensors to numpy
# arrays to help with the weighted loss calculation.
unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    run_eagerly=True,
    metrics=[tf.keras.metrics.Precision(name='precision'),
             tf.keras.metrics.Recall(name='recall')]
)

# %%
class SanityCheck(keras.callbacks.Callback):
    '''Callback that runs a few held-out test images through the network at the
       end of every epoch and plots the input, ground truth, and prediction
       side by side.'''

    def __init__(self, testing_images):
        super(SanityCheck, self).__init__()
        self.testing_images = testing_images

    def on_epoch_end(self,epoch, logs=None):
        for image_pair in self.testing_images:
            out = self.model.predict(image_pair[0],verbose=0)
            image = cv.cvtColor(np.squeeze(np.asarray(image_pair[0]).copy()),cv.COLOR_BGR2RGB)
            squeezed_gt = tf.argmax(image_pair[1],axis=-1)
            squeezed_prediction = tf.argmax(out,axis=-1)

            # per-class ground truth and prediction maps for the vascular (4)
            # and neural (5) classes, kept for optional closer inspection
            vasc_gt = np.squeeze(image_pair[1][0,:,:,4])
            neural_gt = np.squeeze(image_pair[1][0,:,:,5])
            vasc_pred = np.squeeze(out[0,:,:,4])
            neural_pred = np.squeeze(out[0,:,:,5])

            fig,ax = plt.subplots(1,3)

            ax[0].imshow(image)
            ax[1].imshow(squeezed_gt[0,:,:],vmin=0, vmax=7)
            ax[2].imshow(squeezed_prediction[0,:,:],vmin=0, vmax=7)
            # ax[1].imshow(squeezed_gt[0,:,:]==4)
            # ax[2].imshow(squeezed_prediction[0,:,:]==4)
            plt.show()
            print(np.unique(squeezed_gt[0,:,:]))
            print(np.unique(squeezed_prediction[0,:,:]))


# pull a few test samples out of the dataset for the sanity-check callback
test_images = []
for sample in testing_dataset.take(5):
    #print(sample[0].shape)
    test_images.append([sample[0],sample[1]])

# %%

# creating callbacks
sanity_check = SanityCheck(test_images)

# learning rate schedule: decay the learning rate by a factor of 0.7 on every
# epoch whose index is divisible by 3
def schedule(epoch, lr):
    if (epoch % 3) == 0:
        return(lr*0.7)
    else:
        return(lr)

# note that this scheduler is defined but not passed to fit() below; the
# ReduceLROnPlateau callback handles the learning rate instead
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                 mode='min',
                                                 factor=0.8,
                                                 patience=5,
                                                 min_lr=0.000001,
                                                 verbose=True,
                                                 min_delta=0.01,)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('unet_seg_weights.{epoch:02d}-{val_loss:.2f}-{val_precision:.2f}-{val_recall:.2f}.h5',
                                                   save_weights_only=True,
                                                   monitor='loss',
                                                   mode='min',
                                                   verbose=True)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,
                                                     monitor='loss',
                                                     mode='min',
                                                     restore_best_weights=True,
                                                     verbose=True,
                                                     min_delta=0.01)

# setting the number of batches to iterate through each epoch to a value much
# lower than what it normally would be so that we can actually see what is
# going on with the network, as well as have a meaningful early stopping.

# %% fit the network!
# unet.load_weights('./unet_seg_weights.50-0.64-0.93-0.91.h5')
num_steps = 100

# per-class weights passed to fit via class_weight (class 0 is the background)
weights = {0:0,1:2.72403952,2:2.81034368,3:4.36437716,4:36.66264202, 5:108.40694198, 6:87.39903838}

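# The values above reflect the class distribution of the training set (see the
# weighted_cce_loss docstring). A sketch of how weights of this kind could be
# recomputed by inverse pixel frequency (illustrative only, not necessarily how
# these exact numbers were produced; uncomment to run, it makes one full pass
# over the training shards):
# class_counts = np.zeros(7, dtype=np.float64)
# for _, seg in get_dataset(train_files, batch_size=1):
#     class_counts += tf.reduce_sum(seg, axis=[0, 1, 2]).numpy()
# freq = class_counts / class_counts.sum()
# recomputed = {i: 0.0 if i == 0 else float(freq.mean() / max(f, 1e-12))
#               for i, f in enumerate(freq)}
# print(recomputed)
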
history = unet.fit(training_dataset,
                   epochs=100,
                   steps_per_epoch=num_steps,
                   validation_data=validation_dataset,
                   class_weight=weights,
                   callbacks=[checkpoint_cb,
                              early_stopping_cb,
                              reduce_lr,
                              sanity_check,])
# %%



# %%
# evaluate the network after loading the weights
# unet.load_weights('./unet_seg_weights.49-0.52-0.94-0.92.h5')
results = unet.evaluate(testing_dataset)
print(results)
# %%
# extracting loss vs epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
# extracting precision vs epoch
precision = history.history['precision']
val_precision = history.history['val_precision']
# extracting recall vs epoch
recall = history.history['recall']
val_recall = history.history['val_recall']

epochs = range(len(loss))

figs, axes = plt.subplots(3,1)

# plotting loss and validation loss
axes[0].plot(epochs,loss)
axes[0].plot(epochs,val_loss)
axes[0].legend(['loss','val_loss'])
axes[0].set(xlabel='epochs',ylabel='crossentropy loss')

# plotting precision and validation precision
axes[1].plot(epochs,precision)
axes[1].plot(epochs,val_precision)
axes[1].legend(['precision','val_precision'])
axes[1].set(xlabel='epochs',ylabel='precision')

# plotting recall and validation recall
axes[2].plot(epochs,recall)
axes[2].plot(epochs,val_recall)
axes[2].legend(['recall','val_recall'])
axes[2].set(xlabel='epochs',ylabel='recall')


# %% exploring the predictions to better understand what the network is doing

images = []
gt = []
predictions = []

# taking the next 10 samples from the testing dataset and iterating
# through them
for sample in testing_dataset.take(10):
    # make sure it is producing the correct dimensions
    print(sample[0].shape)
    # take the image and convert it back to RGB, store in list
    image = sample[0]
    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
    images.append(image)
    # extract the ground truth and store in list
    ground_truth = sample[1]
    gt.append(ground_truth)
    # perform inference
    out = unet.predict(sample[0])
    predictions.append(out)
    # show the original input image
    plt.imshow(image)
    plt.show()
    # flatten the ground truth from one-hot encoding along the last axis, and
    # show the resulting image
    squeezed_gt = tf.argmax(ground_truth,axis=-1)
    squeezed_prediction = tf.argmax(out,axis=-1)
    plt.imshow(squeezed_gt[0,:,:],vmin=0, vmax=6)
    # print the classes present in this tile
    print(np.unique(squeezed_gt))
    plt.show()
    # show the flattened predictions
    plt.imshow(squeezed_prediction[0,:,:],vmin=0, vmax=6)
    print(np.unique(squeezed_prediction))
    plt.show()

# %%
# select one of the images cycled through above to investigate further
image_to_investigate = 6

# show the original image
plt.imshow(images[image_to_investigate])
plt.show()

# show the ground truth for this tile
squeezed_gt = tf.argmax(gt[image_to_investigate],axis=-1)
plt.imshow(squeezed_gt[0,:,:])
# print the unique classes present in the ground truth
print(np.unique(squeezed_gt))
plt.show()
# flatten the prediction, and show the probability map for class 3
squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
plt.imshow(predictions[image_to_investigate][0,:,:,3])
plt.show()
# show the flattened prediction
plt.imshow(squeezed_prediction[0,:,:])
print(np.unique(squeezed_prediction))
plt.show()

# %%