Diff of /uNet_Subclassed_Large.py [000000] .. [3b7fea]

# %% importing packages

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
from skimage import measure
import cv2 as cv
import os
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 15]

# %% Citations
#############################################################
#############################################################
# https://www.tensorflow.org/guide/keras/functional
# https://www.tensorflow.org/tutorials/customization/custom_layers
# https://keras.io/examples/keras_recipes/tfrecord/
# https://arxiv.org/abs/1505.04597
# https://www.tensorflow.org/guide/gpu

# Defining Functions
#############################################################
#############################################################

def parse_tf_elements(element):
    '''This function is the mapper function for retrieving examples from the
       tfrecord.'''

    # create placeholders for all the features in each example
    data = {
        'height' : tf.io.FixedLenFeature([],tf.int64),
        'width' : tf.io.FixedLenFeature([],tf.int64),
        'raw_image' : tf.io.FixedLenFeature([],tf.string),
        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
        'bbox_x' : tf.io.VarLenFeature(tf.float32),
        'bbox_y' : tf.io.VarLenFeature(tf.float32),
        'bbox_height' : tf.io.VarLenFeature(tf.float32),
        'bbox_width' : tf.io.VarLenFeature(tf.float32)
    }

    # pull out the current example
    content = tf.io.parse_single_example(element, data)

    # pull out each feature from the example
    height = content['height']
    width = content['width']
    raw_seg = content['raw_seg']
    raw_image = content['raw_image']
    bbox_x = content['bbox_x']
    bbox_y = content['bbox_y']
    bbox_height = content['bbox_height']
    bbox_width = content['bbox_width']

    # parse the serialized tensors back to uint8 and reshape them to their
    # original dimensions
    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
    image = tf.reshape(image,shape=[height,width,3])
    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)
    segmentation = tf.reshape(segmentation,shape=[height,width,1])
    one_hot_seg = tf.one_hot(tf.squeeze(segmentation),7,axis=-1)

    # there is currently a bug with returning the bboxes, but fixing it isn't
    # necessary for the initial segmentation-focused uNet exploration.

    # bbox = [bbox_x,bbox_y,bbox_height,bbox_width]
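
    # Possible fix (untested sketch, not part of the original pipeline):
    # VarLenFeature parsing yields SparseTensors, so the bbox components would
    # need to be densified before being stacked and returned, e.g.:
    # bbox = tf.stack([tf.sparse.to_dense(bbox_x),
    #                  tf.sparse.to_dense(bbox_y),
    #                  tf.sparse.to_dense(bbox_height),
    #                  tf.sparse.to_dense(bbox_width)], axis=-1)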

    return(image,one_hot_seg)

#############################################################

class EncoderBlock(layers.Layer):
    '''This layer implements an encoder block with two convolutional layers and
       an option for returning both a max-pooled output (with a stride and pool
       size of (2,2)) and the output of the second convolution, which is used
       for the skip connections implemented later in the decoding section of
       the network. All padding is set to "same" for cleanliness.

       When initializing, it receives the number of filters to be used in both
       of the convolutional layers as well as the kernel size and stride for
       those same layers. It also receives the trainable variable for use with
       the batch normalization layers.'''

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 strides=(1,1),
                 trainable=True,
                 name='encoder_block',
                 **kwargs):

        super(EncoderBlock,self).__init__(trainable, name, **kwargs)
        # When initializing, this object receives a trainable parameter for
        # freezing the convolutional layers.

        # including the image normalization within the network for easier image
        # processing during inference
        self.image_normalization = layers.Normalization()

        # below creates the first of two convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv1',
                      trainable=trainable)

        # second of two convolutional layers
        self.conv2 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv2',
                      trainable=trainable)

        # creates the max-pooling layer for downsampling the image
        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
                                    strides=(2,2),
                                    padding='same',
                                    name='enc_pool')

        # ReLU layer for activations
        self.ReLU = layers.ReLU()

        # both batch normalization layers for use with their corresponding
        # convolutional layers
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self,input,training=True,include_pool=True):

        # first conv of the encoder block
        x = self.image_normalization(input)
        x = self.conv1(x)
        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second conv of the encoder block
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # if include_pool is True, also compute and return the max-pooled
        # output; the un-pooled output is what feeds the skip connections
        # later in the network.
        if include_pool:
            pooled_x = self.enc_pool(x)
            return(x,pooled_x)

        else:
            return(x)


#############################################################

class DecoderBlock(layers.Layer):
    '''This layer implements a decoder block that, when called, receives both an
       input and a "skip connection". The input is passed to the
       "up convolution" or transpose conv layer to double the dimensions before
       being concatenated with its associated skip connection from the encoder
       section of the network. All padding is set to "same" for cleanliness.
       The decoder block also has an option for including an additional
       "segmentation" layer, which is a (1,1) convolution with 7 filters that
       produces the logits for the one-hot encoded ground truth.

       When initializing, it receives the number of filters to be used in the
       up convolutional layer as well as in the other two forward convolutions.
       The received kernel_size and strides are used for the forward
       convolutions, with the up convolution kernel and stride set to (2,2).'''
    def __init__(self,
                 filters,
                 trainable=True,
                 kernel_size=(3,3),
                 strides=(1,1),
                 name='DecoderBlock',
                 **kwargs):

        super(DecoderBlock,self).__init__(trainable, name, **kwargs)

        # creating the up convolution layer
        self.up_conv = layers.Conv2DTranspose(filters=filters,
                                              kernel_size=(2,2),
                                              strides=(2,2),
                                              padding='same',
                                              name='decoder_upconv',
                                              trainable=trainable)

        # the first of two forward convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name ='decoder_conv1',
                                   trainable=trainable)

        # second convolutional layer
        self.conv2 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name ='decoder_conv2',
                                   trainable=trainable)

        # this creates the output prediction logits layer, one filter per class
        self.seg_out = layers.Conv2D(filters=7,
                        kernel_size=(1,1),
                        name='conv_feature_map')

        # ReLU for activation of all above layers
        self.ReLU = layers.ReLU()

        # the individual batch normalization layers for their respective
        # convolutional layers
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()


    def call(self,input,skip_conn,training=True,segmentation=False):

        up = self.up_conv(input) # perform the up convolution
        # concatenate the up-convolved input and the skip_conn along the
        # features axis
        concatenated = layers.concatenate([up,skip_conn],axis=-1)

        # first convolution
        x = self.conv1(concatenated)
        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second convolution
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # if segmentation is True, run the segmentation (1,1) convolution
        # and use Softmax to produce a probability distribution over classes.
        if segmentation:
            seg = self.seg_out(x)
            # deliberately set as "float32" to ensure proper calculation if
            # switching to mixed precision for efficiency
            prob = layers.Softmax(dtype='float32')(seg)
            return(prob)

        else:
            return(x)

#############################################################

class uNet(keras.Model):
    '''This is a sub-classed model that uses the encoder and decoder blocks
       defined above to create a custom unet. The differences from the original
       paper include a variable filter scalar (filter_multiplier), batch
       normalization between each convolutional layer and the associated ReLU
       activation, as well as feature normalization implemented in the first
       layer of the network.'''
    def __init__(self,filter_multiplier=2,**kwargs):
        super(uNet,self).__init__()

        # Defining encoder blocks
        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
                                           name='Enc1')
        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
                                           name='Enc2')
        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
                                           name='Enc3')
        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
                                           name='Enc4')
        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
                                           name='Enc5')
        self.encoder_block6 = EncoderBlock(filters=64*filter_multiplier,
                                           name='Enc6')

        # Defining decoder blocks. The names are in reverse order to make it
        # (hopefully) easier to understand which skip connections are associated
        # with which decoder layers.
        self.decoder_block5 = DecoderBlock(filters=32*filter_multiplier,
                                           name='Dec5')
        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
                                           name='Dec4')
        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
                                           name='Dec3')
        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
                                           name='Dec2')
        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
                                           name='Dec1')


    def call(self,inputs,training):

        # encoder
        enc1,enc1_pool = self.encoder_block1(input=inputs,training=training)
        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
        enc5,enc5_pool = self.encoder_block5(input=enc4_pool,training=training)
        enc6 = self.encoder_block6(input=enc5_pool,
                                   include_pool=False,
                                   training=training)

        # decoder
        dec5 = self.decoder_block5(input=enc6,skip_conn=enc5,training=training)
        dec4 = self.decoder_block4(input=dec5,skip_conn=enc4,training=training)
        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
        seg_logits_out = self.decoder_block1(input=dec2,
                                             skip_conn=enc1,
                                             segmentation=True,
                                             training=training)

        return(seg_logits_out)

#############################################################

def load_dataset(file_names):
    '''Receives a list of file names from a folder that contains tfrecord files
       compiled previously. Takes these names and creates a tensorflow dataset
       from them.'''

    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(file_names)

    # you can shard the dataset if you would like to reduce its size when
    # necessary; this keeps only every third record.
    dataset = dataset.shard(num_shards=3,index=1)

    # the order of the file names doesn't really matter, so ignore it
    dataset = dataset.with_options(ignore_order)

    # map the dataset using the parse_tf_elements function defined earlier
    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)

    return(dataset)

#############################################################

def get_dataset(file_names,batch_size):
    '''Receives a list of file names of tfrecord shards from a dataset as well
       as a batch size for the dataset.'''

    # uses the load_dataset function to retrieve the files and put them into a
    # dataset.
    dataset = load_dataset(file_names)

    # creates a shuffle buffer of 1000. The number was chosen arbitrarily, so
    # feel free to alter it to fit your hardware.
    dataset = dataset.shuffle(1000)

    # adding the batch size to the dataset
    dataset = dataset.batch(batch_size=batch_size)
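
    # Not part of the original script: a hedged suggestion — prefetching would
    # let the input pipeline overlap with training if loading ever becomes a
    # bottleneck.
    # dataset = dataset.prefetch(tf.data.AUTOTUNE)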

    return(dataset)

#############################################################

def weighted_cce_loss(y_true,y_pred):
    '''Yes, this function essentially does what the "fit" argument
       "class_weight" does when training a network. I had to create this
       separate custom loss function because, apparently, when using tfrecord
       files to read your dataset a check is performed comparing the input,
       ground truth, and weight values to each other. However, comparing the
       empty None that is passed during the build call of the model with the
       weight array/dictionary raises an error. Thus, here is a custom loss
       function that applies a weighting to the different classes based on the
       distribution of the classes within the entire dataset. For thoroughness'
       sake, future iterations of the dataset will base the weights only on the
       subset used for training, not the whole dataset.'''

    # one weight per class; the first four are background, connective, muscle,
    # and vasculature
    weights = [0,10.52735078, 2.3808943, 2.44062288, 250.61600774,  8,  20]

    # build a per-pixel weight map for the current batch (the weighting for
    # categorical crossentropy needs one value per pixel). The input is made
    # a numpy array rather than an eager tensor to allow boolean index masking;
    # this is why the model is compiled with run_eagerly=True.
    current_weights = np.asarray(tf.argmax(y_true,axis=-1)).copy().astype(
                                                                    np.float64)
    for idx,weight in enumerate(weights):
        # create a mask for the current class and replace the class index with
        # its weight. The resulting map is passed to the loss function and
        # applied to each pixel.
        mask = current_weights==idx
        current_weights[mask] = weight

    cce = tf.keras.losses.CategoricalCrossentropy()
    cce_loss = cce(y_true,y_pred,current_weights)

    return(cce_loss)
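
# Illustrative alternative (not used by the original script): the same
# weighting can be expressed with pure TensorFlow ops, which would avoid the
# numpy round-trip and therefore the need for run_eagerly=True when compiling.
# This is a sketch under the same seven-class assumption, not a validated
# drop-in replacement for the training run below.
def weighted_cce_loss_graph(y_true,y_pred):
    class_weights = tf.constant([0, 10.52735078, 2.3808943, 2.44062288,
                                 250.61600774, 8, 20], dtype=tf.float32)
    # look up the weight of each pixel's ground-truth class
    pixel_weights = tf.gather(class_weights, tf.argmax(y_true, axis=-1))
    cce = tf.keras.losses.CategoricalCrossentropy()
    return cce(y_true, y_pred, sample_weight=pixel_weights)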

#############################################################
#############################################################
# %% Setting up the GPU, and setting memory growth to True so that it is easier
# to see exactly how much memory the training process is taking up. This code
# is from a TensorFlow tutorial.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')

    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    print(e)

# use this to set mixed precision for higher efficiency later if you would like
# mixed_precision.set_global_policy('mixed_float16')

# %% setting up datasets and building model

# directory where the dataset shards are stored
shard_dataset_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards'

os.chdir(shard_dataset_directory)

# only get the file names that follow the shard naming convention
file_names = tf.io.gfile.glob(shard_dataset_directory + \
                              "/shard_*_of_*.tfrecords")

# the first 70% of the names go to the training dataset, the next 20% go to
# the validation dataset, and the last 10% go to the testing dataset.
val_split_idx = int(0.7*len(file_names))
test_split_idx = int(0.9*len(file_names))

# separate the file names out
train_files, val_files, test_files = file_names[:val_split_idx],\
                                     file_names[val_split_idx:test_split_idx],\
                                     file_names[test_split_idx:]

# create the datasets. The training dataset is set to repeat() because the
# batch sizes and epochs are altered from standard practice so that the model
# fits on the graphics card and the console receives more meaningful and
# frequent updates.
training_dataset = get_dataset(train_files,batch_size=3)
training_dataset = training_dataset.repeat()
validation_dataset = get_dataset(val_files,batch_size = 2)
# testing has a batch size of 1 to facilitate visualization of predictions
testing_dataset = get_dataset(test_files,batch_size=1)

# explicitly put the model on the GPU to show how large it is
gpus = tf.config.list_logical_devices('GPU')
with tf.device(gpus[0].name):
    # the filter multiplier scales the filter depths; with a multiplier of 32
    # the deepest encoder block uses 64*32 = 2048 filters.
    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
    unet = uNet(filter_multiplier=32)
    # build the model by calling it on a 1024x1024 sample image
    out = unet(sample_data)
    unet.summary()
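
# Illustrative check (not part of the original script): with "same" padding,
# five pooling stages, and five matching up-convolutions, the output keeps the
# input's spatial size and holds one softmax probability per class per pixel.
# assert tuple(out.shape) == (1, 1024, 1024, 7)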
# %%
# running the network eagerly because it allows us to convert a tensor to a
# numpy array inside the weighted loss calculation.
unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.00008),
    loss=weighted_cce_loss,
    run_eagerly=True,
    metrics=[tf.keras.metrics.Precision(name='precision'),
             tf.keras.metrics.Recall(name='recall')]
)

# %%

# creating callbacks
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_recall',
                                                 mode='max',
                                                 factor=0.8,
                                                 patience=3,
                                                 min_lr=0.000001,
                                                 verbose=True)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('unet_seg_subclassed.h5',
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   monitor='val_precision',
                                                   mode='max',
                                                   verbose=True)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=8,
                                                     monitor='val_recall',
                                                     mode='max',
                                                     restore_best_weights=True,
                                                     verbose=True)

# setting the number of batches to iterate through each epoch to a value much
# lower than it normally would be, so that we can actually see what is going
# on with the network and so early stopping remains meaningful.
num_steps = 250

# fit the network!
history = unet.fit(training_dataset,
                   epochs=70,
                   steps_per_epoch=num_steps,
                   validation_data=validation_dataset,
                   callbacks=[checkpoint_cb,
                              early_stopping_cb,
                              reduce_lr])
# %%



# %%
# evaluate the network after loading the weights
unet.load_weights('./unet_seg_subclassed.h5')
results = unet.evaluate(testing_dataset)

# %%
# extracting loss vs epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
# extracting precision vs epoch
precision = history.history['precision']
val_precision = history.history['val_precision']
# extracting recall vs epoch
recall = history.history['recall']
val_recall = history.history['val_recall']

epochs = range(len(loss))

figs, axes = plt.subplots(3,1)

# plotting loss and validation loss
axes[0].plot(epochs,loss)
axes[0].plot(epochs,val_loss)
axes[0].legend(['loss','val_loss'])
axes[0].set(xlabel='epochs',ylabel='crossentropy loss')

# plotting precision and validation precision
axes[1].plot(epochs,precision)
axes[1].plot(epochs,val_precision)
axes[1].legend(['precision','val_precision'])
axes[1].set(xlabel='epochs',ylabel='precision')

# plotting recall and validation recall
axes[2].plot(epochs,recall)
axes[2].plot(epochs,val_recall)
axes[2].legend(['recall','val_recall'])
axes[2].set(xlabel='epochs',ylabel='recall')


# %% exploring the predictions to better understand what the network is doing

images = []
gt = []
predictions = []

# take 10 of the next samples from the testing dataset and iterate
# through them
for sample in testing_dataset.take(10):
    # make sure it is producing the correct dimensions
    print(sample[0].shape)
    # take the image, convert it back to RGB, and store it in the list
    image = sample[0]
    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
    images.append(image)
    # extract the ground truth and store it in the list
    ground_truth = sample[1]
    gt.append(ground_truth)
    # perform inference
    out = unet.predict(sample[0])
    predictions.append(out)
    # show the original input image
    plt.imshow(image)
    plt.show()
    # flatten the ground truth from one-hot encoding along the last axis, and
    # show the resulting image
    squeezed_gt = tf.argmax(ground_truth,axis=-1)
    squeezed_prediction = tf.argmax(out,axis=-1)
    plt.imshow(squeezed_gt[0,:,:],vmin=0, vmax=6)
    # print the classes present in this tile
    print(np.unique(squeezed_gt))
    plt.show()
    # show the flattened predictions
    plt.imshow(squeezed_prediction[0,:,:],vmin=0, vmax=6)
    print(np.unique(squeezed_prediction))
    plt.show()

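# Illustrative extension (not in the original script): a per-class pixel
# confusion matrix between the last ground truth and prediction would show
# which of the seven classes the network confuses, e.g.:
# conf = tf.math.confusion_matrix(tf.reshape(squeezed_gt,[-1]),
#                                 tf.reshape(squeezed_prediction,[-1]),
#                                 num_classes=7)
# print(conf)
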
# %%
# select one of the images cycled through above to investigate further
image_to_investigate = 6

# show the original image
plt.imshow(images[image_to_investigate])
plt.show()

# show the ground truth for this tile
squeezed_gt = tf.argmax(gt[image_to_investigate],axis=-1)
plt.imshow(squeezed_gt[0,:,:])
# print the classes present in the ground truth
print(np.unique(squeezed_gt))
plt.show()

# flatten the prediction, and show the predicted probability map for class 3
squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
plt.imshow(predictions[image_to_investigate][0,:,:,3])
plt.show()

# show the flattened prediction
plt.imshow(squeezed_prediction[0,:,:])
print(np.unique(squeezed_prediction))
plt.show()

# %%