# %% importing packages

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
from skimage import measure
import cv2 as cv
import os
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [5, 5]
# you can alternatively call this script with the line below in the terminal to
# address a memory leak that occurs when using the dataset.shuffle buffer,
# found at the subsequent link.
# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4.5.9 python3 uNet_Subclassed.py

# https://stackoverflow.com/questions/55211315/memory-leak-with-tf-data/66971031#66971031


# %% Citations
#############################################################
#############################################################
# https://www.tensorflow.org/guide/keras/functional
# https://www.tensorflow.org/tutorials/customization/custom_layers
# https://keras.io/examples/keras_recipes/tfrecord/
# https://arxiv.org/abs/1505.04597
# https://www.tensorflow.org/guide/gpu

# Defining Functions
#############################################################
#############################################################

def parse_tf_elements(element):
    '''This function is the mapper function for retrieving examples from the
       tfrecord'''

    # create placeholders for all the features in each example
    data = {
        'height' : tf.io.FixedLenFeature([],tf.int64),
        'width' : tf.io.FixedLenFeature([],tf.int64),
        'raw_image' : tf.io.FixedLenFeature([],tf.string),
        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
        'bbox_x' : tf.io.VarLenFeature(tf.float32),
        'bbox_y' : tf.io.VarLenFeature(tf.float32),
        'bbox_height' : tf.io.VarLenFeature(tf.float32),
        'bbox_width' : tf.io.VarLenFeature(tf.float32),
        'name' : tf.io.FixedLenFeature([],tf.string),
    }

    # pull out the current example
    content = tf.io.parse_single_example(element, data)

    # pull out each feature from the example
    height = content['height']
    width = content['width']
    raw_seg = content['raw_seg']
    raw_image = content['raw_image']
    name = content['name']

    # note that the bounding boxes are included here, but are not used. They
    # were included in the dataset for future use, e.g. for putting together
    # something like YOLO for practice. They haven't been thoroughly tested,
    # so they could be buggy and should be vetted before use.
    bbox_x = content['bbox_x']
    bbox_y = content['bbox_y']
    bbox_height = content['bbox_height']
    bbox_width = content['bbox_width']

    # convert the images to uint8, and reshape them accordingly
    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
    image = tf.reshape(image,shape=[height,width,3])
    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)-1
    # The class weights are included in the parser, enabling them to be used
    # by the loss function to weight the loss and accuracy metrics. Note that
    # the last three weights are manually rescaled (two up, one down) to
    # correct the over- and under-segmentation observed when using the
    # unscaled weights. An earlier set of computed weights is kept below for
    # reference.
    # [2.72403952, 2.81034368, 4.36437716, 36.66264202, 108.40694198, 87.39903838]
    weights = [2.15248481,
               3.28798466,
               5.18559616,
               46.96594578*3,
               130.77512742*2,
               105.23678672/2]
    weights = np.divide(weights,sum(weights))

    # the weights are calculated by the tf_record_weight_determination.py file,
    # and are related to the percentages of each class in the dataset.
    sample_weights = tf.gather(weights, indices=tf.cast(segmentation, tf.int32))

    return(image,segmentation,sample_weights)

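# A minimal usage sketch (hypothetical shard name), commented out: mapping
# parse_tf_elements over a raw TFRecordDataset yields
# (image, segmentation, sample_weights) triplets ready for shuffling and
# batching.
#
# ds = tf.data.TFRecordDataset(['shard_1_of_10.tfrecords'])
# ds = ds.map(parse_tf_elements)
# for image, seg, weights in ds.take(1):
#     print(image.shape, seg.shape, weights.shape)
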
#############################################################

class EncoderBlock(layers.Layer):
    '''This layer implements an encoder block with two convolutional layers and
       an option for returning both a max-pooled output with a stride and pool
       size of (2,2) and the output of the second convolution for skip
       connections implemented later in the network during the decoding
       section. All padding is set to "same" for cleanliness.

       When initializing it receives the number of filters to be used in both
       of the convolutional layers as well as the kernel size and stride for
       those same layers. It also receives the trainable variable for use with
       the batch normalization layers.'''

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 strides=(1,1),
                 trainable=True,
                 name='encoder_block',
                 **kwargs):

        super(EncoderBlock,self).__init__(trainable=trainable,name=name,**kwargs)
        # when initializing, this object receives a trainable parameter for
        # freezing the convolutional layers.

        # including the image normalization within the network for easier image
        # processing during inference
        self.image_normalization = layers.Rescaling(scale=1./255)

        # below creates the first of two convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv1',
                      trainable=trainable)

        # second of two convolutional layers
        self.conv2 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv2',
                      trainable=trainable)

        # creates the max-pooling layer for downsampling the image.
        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
                                    strides=(2,2),
                                    padding='same',
                                    name='enc_pool')

        # ReLU layer for activations.
        self.ReLU = layers.ReLU()

        # both batch normalization layers for use with their corresponding
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self,input,normalization=False,training=True,include_pool=True):

        # first conv of the encoder block
        if normalization:
            x = self.image_normalization(input)
            x = self.conv1(x)
        else:
            x = self.conv1(input)

        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second conv of the encoder block
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # calculate and include the max pooling layer if include_pool is true.
        # This output is used for the skip connections later in the network.
        if include_pool:
            pooled_x = self.enc_pool(x)
            return(x,pooled_x)

        else:
            return(x)


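# A commented-out shape sketch with dummy values: when include_pool is True the
# block returns both the full-resolution features (for the skip connection) and
# the (2,2) max-pooled features for the next encoder stage.
#
# block = EncoderBlock(filters=8)
# x = tf.zeros((1, 64, 64, 3))
# skip, pooled = block(input=x, normalization=True, training=False)
# print(skip.shape, pooled.shape)  # (1, 64, 64, 8) (1, 32, 32, 8)
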
#############################################################

class DecoderBlock(layers.Layer):
    '''This layer implements a decoder block that when called receives both an
       input and a "skip connection". The input is passed to the
       "up convolution" or transpose conv layer to double the dimensions before
       being concatenated with its associated skip connection from the encoder
       section of the network. All padding is set to "same" for cleanliness.
       The decoder block also has an option for including an additional
       "segmentation" layer, which is a (1,1) convolution with 6 filters, which
       produces the logits for the segmentation classes.

       When initializing it receives the number of filters to be used in the
       up convolutional layer as well as the other two forward convolutions.
       The received kernel_size and stride is used for the forward convolutions,
       with the up convolution kernel and stride set to be (2,2).'''
    def __init__(self,
                 filters,
                 trainable=True,
                 kernel_size=(3,3),
                 strides=(1,1),
                 name='DecoderBlock',
                 **kwargs):

        super(DecoderBlock,self).__init__(trainable=trainable,name=name,**kwargs)

        # creating the up convolution layer
        self.up_conv = layers.Conv2DTranspose(filters=filters,
                                              kernel_size=(2,2),
                                              strides=(2,2),
                                              padding='same',
                                              name='decoder_upconv',
                                              trainable=trainable)

        # the first of two forward convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name='decoder_conv1',
                                   trainable=trainable)

        # second convolutional layer
        self.conv2 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name='decoder_conv2',
                                   trainable=trainable)

        # this creates the output prediction logits layer.
        self.seg_out = layers.Conv2D(filters=6,
                        kernel_size=(1,1),
                        name='conv_feature_map')

        # ReLU for activation of all above layers
        self.ReLU = layers.ReLU()

        # the individual batch normalization layers for their respective
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()


    def call(self,input,skip_conn,training=True,segmentation=False,prob_dist=True):

        up = self.up_conv(input) # perform image up convolution
        # concatenate the input and the skip_conn along the features axis
        concatenated = layers.concatenate([up,skip_conn],axis=-1)

        # first convolution
        x = self.conv1(concatenated)
        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second convolution
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # if segmentation is True, then run the segmentation (1,1) convolution
        # and use the Softmax to produce a probability distribution.
        if segmentation:
            seg = self.seg_out(x)
            # deliberately set as "float32" to ensure proper calculation if
            # switching to mixed precision for efficiency
            if prob_dist:
                seg = layers.Softmax(dtype='float32')(seg)

            return(seg)

        else:
            return(x)

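# A commented-out shape sketch with dummy values: the transpose convolution
# doubles the spatial dimensions of the input before concatenation, so
# skip_conn must have twice the height and width of the input.
#
# block = DecoderBlock(filters=8)
# low_res = tf.zeros((1, 32, 32, 16))
# skip = tf.zeros((1, 64, 64, 8))
# print(block(input=low_res, skip_conn=skip, training=False).shape)
# # -> (1, 64, 64, 8)
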
#############################################################

class uNet(keras.Model):
    '''This is a sub-classed model that uses the encoder and decoder blocks
       defined above to create a custom unet. The differences from the original
       paper include a variable filter scalar (filter_multiplier), batch
       normalization between each convolutional layer and the associated ReLU
       activation, as well as feature normalization implemented in the first
       layer of the network.'''
    def __init__(self,filter_multiplier=2,**kwargs):
        super(uNet,self).__init__(**kwargs)

        # Defining encoder blocks
        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
                                           name='Enc1')
        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
                                           name='Enc2')
        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
                                           name='Enc3')
        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
                                           name='Enc4')
        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
                                           name='Enc5')

        # Defining decoder blocks. The names are in reverse order to make it
        # (hopefully) easier to understand which skip connections are
        # associated with which decoder layers.
        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
                                           name='Dec4')
        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
                                           name='Dec3')
        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
                                           name='Dec2')
        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
                                           name='Dec1')


    def call(self,inputs,training,predict=False,threshold=3):

        # encoder
        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
        enc5 = self.encoder_block5(input=enc4_pool,
                                   include_pool=False,
                                   training=training)

        # enc4 = self.encoder_block4(input=enc3_pool,
        #                            include_pool=False,
        #                            training=training)


        # decoder
        dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
        prob_dist_out = self.decoder_block1(input=dec2,
                                            skip_conn=enc1,
                                            segmentation=True,
                                            training=training)
        if predict:
            seg_logits_out = self.decoder_block1(input=dec2,
                                                 skip_conn=enc1,
                                                 segmentation=True,
                                                 training=training,
                                                 prob_dist=False)

        # The predict branch below lets one set a confidence threshold: an
        # arbitrary value compared against the maximum logit predicted at each
        # point in the image. Predictions for the vascular and neural tissues
        # are kept only if they are above the confidence threshold; below the
        # threshold, those pixels default to muscle, connective, or background.

        if predict:
            # rename the value for consistency and write protection.
            y_pred = seg_logits_out
            pred_shape = (1,1024,1024,6)
            # Getting an image-sized preliminary segmentation prediction
            squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))

            # initializing the variable used for storing the maximum logits at
            # each pixel location.
            max_value_predictions = tf.zeros((1024,1024))

            # cycle through all the classes
            for idx in range(6):

                # current class logits
                current_slice = tf.squeeze(y_pred[:,:,:,idx])
                # find the locations where this class is predicted
                current_indices = squeezed_prediction == idx
                # define the shape so that this function can run in graph mode
                # and not need eager execution.
                current_indices.set_shape((1024,1024))
                # Get the indices of where the idx class is predicted
                indices = tf.where(squeezed_prediction == idx)
                # get the output of boolean_mask to enable scatter update of
                # the tensor. This is required because tensors do not support
                # mask indexing.
                values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
                # Place the maximum logit values at each point in an
                # image-size matrix, indicating the confidence in the
                # prediction at each pixel.
                max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))

            for idx in [3,4]:
                mask_list = []
                for idx2 in range(6):
                    if idx2 == idx:
                        mid_mask = max_value_predictions<threshold
                        mask_list.append(mid_mask.astype(tf.float32))
                    else:
                        mask_list.append(tf.zeros((1024,1024)))

                mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)

                indexes = tf.where(mask)
                values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)

                seg_logits_out = tf.tensor_scatter_nd_update(seg_logits_out,indexes,values_updates.astype(tf.float32))
                prob_dist_out = layers.Softmax(dtype='float32')(seg_logits_out)
            # print("updated logits!")


        return(prob_dist_out)


    # def test_step(self, data):

    #     threshold = 3
    #     x, y, weight = data
    #     pred_shape = (1,1024,1024,6)

    #     y_pred = self(x,training=False)

    #     squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))

    #     max_value_predictions = tf.zeros((1024,1024))

    #     for idx in range(6):

    #         current_slice = tf.squeeze(y_pred[:,:,:,idx])
    #         current_indices = squeezed_prediction == idx
    #         current_indices.set_shape((1024,1024))
    #         indices = tf.where(squeezed_prediction == idx)
    #         values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
    #         max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))

    #     for idx in [3,4]:
    #         mask_list = []
    #         for idx2 in range(6):
    #             if idx2 == idx:
    #                 mid_mask = max_value_predictions<threshold
    #                 mask_list.append(mid_mask.astype(tf.float32))
    #             else:
    #                 mask_list.append(tf.zeros((1024,1024)))

    #         mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)

    #         indexes = tf.where(mask)
    #         values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)

    #         y_pred = tf.tensor_scatter_nd_update(y_pred,indexes,values_updates.astype(tf.float32))

    #     self.compiled_metrics.update_state(y, y_pred, sample_weight=weight)
    #     self.compiled_loss(y, y_pred, sample_weight=weight)

    #     return {m.name: m.result() for m in self.metrics}

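# A commented-out inference sketch (assumes the model has been built on
# (1,1024,1024,3) inputs as done below, and that sample_image is such a batch):
# predict=True zeroes out low-confidence vascular and neural logits before the
# final Softmax.
#
# probs = unet(sample_image, training=False, predict=True, threshold=3)
# classes = tf.argmax(probs, axis=-1)  # (1, 1024, 1024) class indices
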
#############################################################

class SanityCheck(keras.callbacks.Callback):

    def __init__(self, testing_images):
        super(SanityCheck, self).__init__()
        self.testing_images = testing_images

    def on_epoch_end(self, epoch, logs=None):
        for image_pair in self.testing_images:
            out = self.model.predict(image_pair[0],verbose=0)
            image = cv.cvtColor(np.squeeze(np.asarray(image_pair[0]).copy()),cv.COLOR_BGR2RGB)
            squeezed_gt = image_pair[1][0,:,:]
            squeezed_prediction = tf.argmax(out,axis=-1)

            fig,ax = plt.subplots(1,3)

            ax[0].imshow(image)
            ax[1].imshow(squeezed_gt,vmin=0, vmax=5)
            ax[2].imshow(squeezed_prediction[0,:,:],vmin=0, vmax=5)

            plt.show()
            print(np.unique(squeezed_gt))
            print(np.unique(squeezed_prediction[0,:,:]))


#############################################################

def load_dataset(file_names):
    '''Receives a list of file names from a folder that contains tfrecord files
       compiled previously. Takes these names and creates a tensorflow dataset
       from them.'''

    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(file_names)

    # you can shard the dataset to reduce its size when necessary; here only
    # one in every eight records is kept.
    dataset = dataset.shard(num_shards=8,index=2)

    # order in the file names doesn't really matter, so ignoring it
    dataset = dataset.with_options(ignore_order)

    # mapping the dataset using the parse_tf_elements function defined earlier
    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)

    return(dataset)

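# Shard semantics as a commented-out toy example: shard(num_shards=n, index=i)
# keeps only the elements whose position modulo n equals i, so the call above
# retains 1/8th of the records.
#
# toy = tf.data.Dataset.range(16).shard(num_shards=8, index=2)
# print(list(toy.as_numpy_iterator()))  # [2, 10]
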
#############################################################

def get_dataset(file_names,batch_size):
    '''Receives a list of file names of tfrecord shards from a dataset as well
       as a batch size for the dataset.'''

    # uses the load_dataset function to retrieve the files and put them into a
    # dataset.
    dataset = load_dataset(file_names)

    # creates a shuffle buffer of 300. The size was chosen arbitrarily; feel
    # free to alter it to fit your hardware.
    dataset = dataset.shuffle(300)

    # adding the batch size to the dataset
    dataset = dataset.batch(batch_size=batch_size)

    return(dataset)


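# After batching, each element of the dataset is an
# (image, segmentation, weights) triplet with a leading batch dimension. A
# commented-out check (hypothetical shard name):
#
# ds = get_dataset(['shard_1_of_10.tfrecords'], batch_size=3)
# for img, seg, w in ds.take(1):
#     print(img.shape, seg.shape, w.shape)  # (3, H, W, 3) (3, H, W) (3, H, W)
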
#############################################################
#############################################################
# %% Setting up the GPU, and setting memory growth to true so that it is easier
# to see exactly how much memory the training process is taking up. This code
# is from a tensorflow tutorial.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')

        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

# use this to set mixed precision for higher efficiency later if you would like
# mixed_precision.set_global_policy('mixed_float16')

# %% setting up datasets and building model

# directory where the dataset shards are stored
home_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_6'
training_directory = home_directory + '/train'
val_directory = home_directory + '/validate'
testing_directory = home_directory + '/test'

os.chdir(home_directory)

# only get the file names that follow the shard naming convention
train_files = tf.io.gfile.glob(training_directory + \
                              "/shard_*_of_*.tfrecords")
val_files = tf.io.gfile.glob(val_directory + \
                              "/shard_*_of_*.tfrecords")
test_files = tf.io.gfile.glob(testing_directory + \
                              "/shard_*_of_*.tfrecords")

# create the datasets. Because of how batches are run for training, we set
# the dataset to repeat(), since the batches and epochs are altered from
# standard practice to fit on graphics cards and provide more meaningful and
# frequent updates to the console.
training_dataset = get_dataset(train_files,batch_size=3)
training_dataset = training_dataset.repeat()
validation_dataset = get_dataset(val_files,batch_size=3)
# testing has a batch size of 1 to facilitate visualization of predictions
testing_dataset = get_dataset(test_files,batch_size=1)

# explicitly puts the model on the GPU to show how large it is.
gpus = tf.config.list_logical_devices('GPU')
with tf.device(gpus[0].name):
    # the filter multiplier scales the filter depths; a multiplier of 8 would
    # give a largest depth of 256, so the multiplier of 12 used here gives 384.
    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
    unet = uNet(filter_multiplier=12,) # 12 is the magic number
    # build the model with an input image size of 1024x1024
    out = unet(sample_data)
    unet.summary()
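# A quick commented-out sanity check on the freshly built model: the output
# should be a per-pixel probability distribution over the 6 classes.
#
# print(out.shape)                              # expect (1, 1024, 1024, 6)
# print(float(tf.reduce_sum(out[0, 0, 0, :])))  # softmax sums to ~1.0
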
# %%

unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    run_eagerly=False,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()]
)

test_images = []
for sample in testing_dataset.take(5):
    # print(sample[0].shape)
    test_images.append([sample[0],sample[1]])

sanity_check = SanityCheck(test_images)


def schedule(epoch, lr):
    return(lr*0.97)

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)

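# With this schedule the learning rate decays geometrically from the compiled
# value: after n epochs it is 0.0002 * 0.97**n, e.g. roughly 1.1e-4 after 20
# epochs and 9.5e-6 after 100 epochs.
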
# note that reduce_lr is defined here, but it is not included in the callbacks
# list passed to fit() below, so only the exponential schedule above is active.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                 mode='min',
                                                 factor=0.8,
                                                 patience=5,
                                                 min_lr=0.000001,
                                                 verbose=True,
                                                 min_delta=0.01,)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    'unet_seg_weights.{epoch:02d}-{val_sparse_categorical_accuracy:.4f}-{val_loss:.4f}.h5',
    save_weights_only=True,
    monitor='val_sparse_categorical_accuracy',
    mode='max',
    verbose=True
    )

early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,
                                                     monitor='val_sparse_categorical_accuracy',
                                                     mode='max',
                                                     restore_best_weights=True,
                                                     verbose=True,
                                                     min_delta=0.001)

# setting the number of batches to iterate through each epoch to a value much
# lower than it normally would be, so that we can actually see what is going
# on with the network, as well as have a meaningful early stopping.


# %% fit the network!
num_steps = 600

history = unet.fit(training_dataset,
                   epochs=100,
                   steps_per_epoch=num_steps,
                   validation_data=validation_dataset,
                   verbose=2,
                   callbacks=[checkpoint_cb,
                              early_stopping_cb,
                              lr_scheduler,])
# %%

# %%
# evaluate the network after loading the weights
unet.load_weights('unet_seg_weights.84-0.9163-0.0053.h5')
results = unet.evaluate(testing_dataset)
print(results)
# %%
# extracting loss and accuracy vs epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']

epochs = range(len(loss))

figs, axes = plt.subplots(2,1)

# plotting loss and validation loss
axes[0].plot(epochs[1:],loss[1:])
axes[0].plot(epochs[1:],val_loss[1:])
axes[0].legend(['loss','val_loss'])
axes[0].set(xlabel='epochs',ylabel='crossentropy loss')

# plotting accuracy and validation accuracy
axes[1].plot(epochs[1:],acc[1:])
axes[1].plot(epochs[1:],val_acc[1:])
axes[1].legend(['acc','val_acc'])
axes[1].set(xlabel='epochs',ylabel='weighted accuracy')


673
# %% exploring the predictions to better understand what the network is doing. 
674
# This section is largely experimental, and should be treated as such. I have
675
# included it in this network file for the sake of documentation and 
676
# traceability, but it is not in the other network files for full image 
677
# segmentation and directory segmentation because, well, those are functional 
678
# and this is experimental.
679
680
681
# uncomment everything from here down to use this section
682
images = []
683
gt = []
684
predictions = []
685
# higher threshold means the network must be more confident.
686
threshold = 3
687
688
# taking out 15 of the next samples from the testing dataset and iterating
# through them
for sample in testing_dataset.take(15):
    # make sure it is producing the correct dimensions
    print(sample[0].shape)
    # take the image and convert it back to RGB, store in list
    image = sample[0]
    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
    images.append(image)
    # extract the ground truth and store in list
    ground_truth = sample[1]
    gt.append(ground_truth)
    # perform inference
    out = unet(sample[0],predict=True,threshold=threshold)
    predictions.append(out)
    # show the original input image
    plt.imshow(image)
    plt.show()
    # the ground truth is stored as class indices rather than one-hot, so it
    # can be displayed directly
    squeezed_gt = ground_truth
    squeezed_prediction = tf.argmax(out,axis=-1)
    plt.imshow(squeezed_gt[0,:,:],vmin=0, vmax=5)
    # print the classes present in this tile
    print(np.unique(squeezed_gt))
    plt.show()
    # show the flattened predictions
    plt.imshow(squeezed_prediction[0,:,:],vmin=0, vmax=5)
    print(np.unique(squeezed_prediction))
    plt.show()

# # %% 5, 6, 8
# # select one of the images cycled through above to investigate further
# image_to_investigate = 0
# threshold = 2
# # show the original image
# plt.imshow(images[image_to_investigate])
# plt.show()

# # show the ground truth for this tile
# squeezed_gt = gt[image_to_investigate]
# plt.imshow(squeezed_gt[0,:,:])
# # print the number of unique classes in the ground truth
# print(np.unique(squeezed_gt))
# plt.show()
# # flatten the prediction and show the probability distribution

# out = predictions[image_to_investigate]


# # plt.hist(out[:,:,:,4].reshape(-1),alpha=0.5,label='neural')
# # plt.hist(out[:,:,:,3].reshape(-1),alpha=0.5,label='vascular')
# # plt.legend(["neural",'vascular'])

# out = predictions[image_to_investigate]
# squeezed_prediction = np.squeeze(tf.argmax(out,axis=-1))

# max_value_predictions = np.zeros(squeezed_prediction.shape)

# for idx in range(6):
#     current_slice = np.squeeze(out[:,:,:,idx])
#     current_indices = squeezed_prediction == idx
#     indices = tf.where(squeezed_prediction == idx)
#     values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
#     max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))

# plt.imshow(max_value_predictions)
# plt.show()

# for idx in [3,4]:
#     mask = np.zeros(out.shape)
#     mask[:,:,:,idx] = max_value_predictions<threshold
#     indices = tf.where(mask)
#     values_updates = tf.boolean_mask(np.zeros(out.shape),mask).astype(tf.double)

#     out = tf.tensor_scatter_nd_update(out,indices,values_updates.astype(tf.float32))

# for idx in range(6):
#     current_slice = np.squeeze(out[:,:,:,idx])
#     current_indices = squeezed_prediction == idx
#     indices = tf.where(squeezed_prediction == idx)
#     values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
#     max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))
# plt.imshow(max_value_predictions)
# plt.show()


# squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
# # plt.show()
# # show the flattened image
# plt.imshow(squeezed_prediction[0,:,:])
# print(np.unique(squeezed_prediction))
# plt.show()

# squeezed_prediction = tf.argmax(out,axis=-1)
# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
# # plt.show()
# # show the flattened image
# plt.imshow(squeezed_prediction[0,:,:])
# print(np.unique(squeezed_prediction))
# plt.show()

# # %%
# image_to_investigate = 0
# threshold = 1
# y_pred = predictions[image_to_investigate]

# pred_shape = (1,1024,1024,6)

# squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))

# max_value_predictions = tf.zeros((1024,1024))

# for idx in range(6):

#     current_slice = tf.squeeze(y_pred[:,:,:,idx])
#     current_indices = squeezed_prediction == idx
#     current_indices.set_shape((1024,1024))
#     indices = tf.where(squeezed_prediction == idx)
#     values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
#     max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))

# for idx in [3,4]:
#     mask_list = []
#     for idx2 in range(6):
#         if idx2 == idx:
#             mid_mask = max_value_predictions<threshold
#             mask_list.append(mid_mask.astype(tf.float32))
#         else:
#             mask_list.append(tf.zeros((1024,1024)))

#     mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)

#     indexes = tf.where(mask)
#     values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)

#     y_pred = tf.tensor_scatter_nd_update(y_pred,indexes,values_updates.astype(tf.float32))

# squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
# # plt.show()
# # show the flattened image
# plt.imshow(squeezed_prediction[0,:,:])
# print(np.unique(squeezed_prediction))
# plt.show()

# squeezed_prediction = tf.argmax(y_pred,axis=-1)
# # plt.imshow(predictions[image_to_investigate][0,:,:,3])
# # plt.show()
# # show the flattened image
# plt.imshow(squeezed_prediction[0,:,:])
# print(np.unique(squeezed_prediction))
# plt.show()
# # %%

# %%