Diff of /uNet_Functional.py [000000] .. [3b7fea]

# %% importing packages

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
from skimage import measure
import cv2 as cv
import os
import tqdm
import matplotlib.pyplot as plt
import gc


# %% Citations
#############################################################
#############################################################
# https://www.tensorflow.org/guide/keras/functional
# https://www.tensorflow.org/tutorials/customization/custom_layers
# https://keras.io/examples/keras_recipes/tfrecord/
# https://arxiv.org/abs/1505.04597
# https://www.tensorflow.org/guide/gpu

# Defining Functions
#############################################################
#############################################################

def parse_tf_elements(element):
    '''Mapper function for retrieving examples from the tfrecord.'''

    # create placeholders for all the features in each example
    data = {
        'height' : tf.io.FixedLenFeature([],tf.int64),
        'width' : tf.io.FixedLenFeature([],tf.int64),
        'raw_image' : tf.io.FixedLenFeature([],tf.string),
        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
        'bbox_x' : tf.io.VarLenFeature(tf.float32),
        'bbox_y' : tf.io.VarLenFeature(tf.float32),
        'bbox_height' : tf.io.VarLenFeature(tf.float32),
        'bbox_width' : tf.io.VarLenFeature(tf.float32)
    }

    # pull out the current example
    content = tf.io.parse_single_example(element, data)

    # pull out each feature from the example
    height = content['height']
    width = content['width']
    raw_seg = content['raw_seg']
    raw_image = content['raw_image']
    bbox_x = content['bbox_x']
    bbox_y = content['bbox_y']
    bbox_height = content['bbox_height']
    bbox_width = content['bbox_width']

    # convert the images to uint8 and reshape them accordingly
    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
    image = tf.reshape(image,shape=[height,width,3])
    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)
    segmentation = tf.reshape(segmentation,shape=[height,width,1])
    # class labels are stored as 1..4, so shift them to 0..3 before
    # one-hot encoding
    one_hot_seg = tf.one_hot(tf.squeeze(segmentation-1),4,axis=-1)

    # returning the bbox currently has a bug, but fixing it isn't necessary
    # for creating the initial uNet for segmentation exploration

    # bbox = [bbox_x,bbox_y,bbox_height,bbox_width]

    return(image,one_hot_seg)
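
#############################################################

# for reference, a minimal sketch of how an example matching the feature
# spec above might be serialized when the shards are written. This is an
# assumption reconstructed from parse_tf_elements; the actual shard-writing
# code is not part of this file.
def serialize_example_sketch(image, segmentation, bboxes):
    '''image: (H,W,3) uint8 array, segmentation: (H,W,1) uint8 array,
       bboxes: list of (x, y, height, width) float tuples.'''
    def _int64(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    def _bytes(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    def _floats(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    feature = {
        'height' : _int64(image.shape[0]),
        'width' : _int64(image.shape[1]),
        'raw_image' : _bytes(tf.io.serialize_tensor(image).numpy()),
        'raw_seg' : _bytes(tf.io.serialize_tensor(segmentation).numpy()),
        'bbox_x' : _floats([box[0] for box in bboxes]),
        'bbox_y' : _floats([box[1] for box in bboxes]),
        'bbox_height' : _floats([box[2] for box in bboxes]),
        'bbox_width' : _floats([box[3] for box in bboxes]),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()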

#############################################################

def load_dataset(file_names):
    '''Receives a list of file names from a folder that contains tfrecord
       files compiled previously. Takes these names and creates a tensorflow
       dataset from them.'''

    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(file_names)

    # the dataset can be sharded here to reduce its size when necessary
    # dataset = dataset.shard(num_shards=2,index=1)

    # the order of the file names doesn't matter, so ignore it
    dataset = dataset.with_options(ignore_order)

    # map the dataset using the parse_tf_elements function defined earlier
    dataset = dataset.map(parse_tf_elements,num_parallel_calls=1)

    return(dataset)

#############################################################

def get_dataset(file_names,batch_size):
    '''Receives a list of file names of tfrecord shards from a dataset as
       well as a batch size for the dataset.'''

    # use the load_dataset function to retrieve the files and put them into
    # a dataset
    dataset = load_dataset(file_names)

    # create a shuffle buffer of 1000 elements. The number was chosen
    # arbitrarily; alter it to fit your hardware.
    dataset = dataset.shuffle(1000)

    # batch the dataset
    dataset = dataset.batch(batch_size=batch_size)

    return(dataset)
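
# performance note (not part of the original pipeline): adding
# dataset = dataset.prefetch(tf.data.AUTOTUNE)
# before the return above would let input preparation overlap with training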

#############################################################

def weighted_cce_loss(y_true,y_pred):
    '''Yes, this function essentially does what the "fit" argument
       "class_weight" does when training a network. A separate custom loss
       function is needed because, apparently, when reading your dataset
       from tfrecord files a check is performed that compares the input,
       ground truth, and weight values to each other; the comparison between
       the empty None passed during the build call of the model and the
       weight array/dictionary returns an error. Thus, here is a custom loss
       function that weights the different classes based on their
       distribution within the entire dataset. For thoroughness' sake,
       future iterations of the dataset will base the weights only on the
       training dataset, not the whole dataset.'''

    # weights for each class: background, connective, muscle, and vasculature
    weights = [28.78661087,3.60830475,1.63037567,14.44688883]

    # make the ground-truth class map a numpy array rather than an eager
    # tensor to allow for boolean index masking. This array must be created
    # once, before the loop; creating it inside the loop would wipe out the
    # weights written by the previous iterations.
    current_weights = np.asarray(tf.argmax(y_true,axis=-1)).copy().astype(
                                                                np.float64)

    # build a per-pixel weight map for the images in the current batch
    # (the weighting for categorical crossentropy needs one value per input)
    for idx,weight in enumerate(weights):
        # mask the pixels belonging to the current class and replace the
        # class index with the class weight; the map is then passed to the
        # loss function so the weight applies to each pixel
        mask = current_weights==idx
        current_weights[mask] = weight

    cce = tf.keras.losses.CategoricalCrossentropy()
    cce_loss = cce(y_true,y_pred,current_weights)

    return(cce_loss)
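
#############################################################

# an equivalent formulation of the weighting above using only tensor ops,
# which avoids the numpy round trip and therefore would not need
# run_eagerly=True. A sketch under the same four class weights; this is not
# the loss actually used below.
def weighted_cce_loss_vectorized(y_true,y_pred):
    '''Weighted categorical crossentropy that looks up per-pixel weights
       with tf.gather instead of numpy boolean masking.'''
    weights = tf.constant([28.78661087,3.60830475,1.63037567,14.44688883])
    # each pixel's weight comes from its ground-truth class index
    pixel_weights = tf.gather(weights,tf.argmax(y_true,axis=-1))
    cce = tf.keras.losses.CategoricalCrossentropy()
    return cce(y_true,y_pred,pixel_weights)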

#############################################################
#############################################################

# %% Setting up the GPU, and setting memory growth to true so that it is
# easier to see exactly how much memory the training process is taking up.
# This code is from a tensorflow tutorial.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

# use this to set mixed precision for higher efficiency later if you would like
# mixed_precision.set_global_policy('mixed_float16')


# %% setting up datasets and building model

# directory where the dataset shards are stored
shard_dataset_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_ScaleFactor2'

os.chdir(shard_dataset_directory)

# only get the file names that follow the shard naming convention
file_names = tf.io.gfile.glob(shard_dataset_directory + \
                              "/shard_*_of_*.tfrecords")

# the first 70% of the file names go to the training dataset, the next 20%
# to the validation dataset, and the last 10% to the testing dataset
val_split_idx = int(0.7*len(file_names))
test_split_idx = int(0.9*len(file_names))

# separate the file names out
train_files, val_files, test_files = file_names[:val_split_idx],\
                                     file_names[val_split_idx:test_split_idx],\
                                     file_names[test_split_idx:]

# create the datasets. The training dataset is set to repeat() because the
# batch and epoch sizes are altered from standard practice, both to fit on
# the graphics card and to provide more meaningful and frequent updates to
# the console.
training_dataset = get_dataset(train_files,batch_size=15)
training_dataset = training_dataset.repeat()
validation_dataset = get_dataset(val_files,batch_size = 5)
# testing has a batch size of 1 to facilitate visualization of predictions
testing_dataset = get_dataset(test_files,batch_size=1)
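
# with batch_size=15 and steps_per_epoch=150 (see the fit call below), one
# console "epoch" covers 150 * 15 = 2250 tiles, which is why the training
# dataset repeats indefinitely rather than being exhausted once per epoch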

# %% Putting together the network

# the filter multiplier scales every stage; a multiplier of 8 gives a
# largest filter depth of 32*8 = 256
filter_multiplier = 8
# encoder convolution parameters
enc_kernel = (3,3)
enc_strides = (1,1)

# encoder max-pooling parameters
enc_pool_size = (2,2)
enc_pool_strides = (2,2)

# setting the input size
net_input = keras.Input(shape=(512,512,3),name='original_image')

################## Encoder ##################
# encoder, block 1

# including the image normalization within the network for easier image
# processing during inference
normalized = layers.Normalization()(net_input)
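
# NOTE: layers.Normalization() only learns meaningful statistics when its
# adapt() method is called on representative data before training; otherwise
# it falls back on its default statistics. One option (an untested sketch,
# not part of the original pipeline) would be:
#
#     norm_layer = layers.Normalization()
#     norm_layer.adapt(training_dataset.take(100).map(lambda x, y: x))
#     normalized = norm_layer(net_input)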

enc1 = layers.Conv2D(filters=2*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc1_conv1')(normalized)

enc1 = layers.BatchNormalization()(enc1)
enc1 = layers.ReLU()(enc1)

enc1 = layers.Conv2D(filters=2*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc1_conv2')(enc1)

enc1 = layers.BatchNormalization()(enc1)
enc1 = layers.ReLU()(enc1)

enc1_pool = layers.MaxPooling2D(pool_size=enc_pool_size,
                                strides=enc_pool_strides,
                                padding='same',
                                name='enc1_pool')(enc1)


# encoder, block 2
enc2 = layers.Conv2D(filters=4*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc2_conv1')(enc1_pool)

enc2 = layers.BatchNormalization()(enc2)
enc2 = layers.ReLU()(enc2)

enc2 = layers.Conv2D(filters=4*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc2_conv2')(enc2)

enc2 = layers.BatchNormalization()(enc2)
enc2 = layers.ReLU()(enc2)

enc2_pool = layers.MaxPooling2D(pool_size=enc_pool_size,
                                strides=enc_pool_strides,
                                padding='same',
                                name='enc2_pool')(enc2)


# encoder, block 3
enc3 = layers.Conv2D(filters=8*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc3_conv1')(enc2_pool)

enc3 = layers.BatchNormalization()(enc3)
enc3 = layers.ReLU()(enc3)

enc3 = layers.Conv2D(filters=8*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc3_conv2')(enc3)

enc3 = layers.BatchNormalization()(enc3)
enc3 = layers.ReLU()(enc3)

enc3_pool = layers.MaxPooling2D(pool_size=enc_pool_size,
                                strides=enc_pool_strides,
                                padding='same',
                                name='enc3_pool')(enc3)


# encoder, block 4
enc4 = layers.Conv2D(filters=16*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc4_conv1')(enc3_pool)

enc4 = layers.BatchNormalization()(enc4)
enc4 = layers.ReLU()(enc4)

enc4 = layers.Conv2D(filters=16*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc4_conv2')(enc4)

enc4 = layers.BatchNormalization()(enc4)
enc4 = layers.ReLU()(enc4)

enc4_pool = layers.MaxPooling2D(pool_size=enc_pool_size,
                                strides=enc_pool_strides,
                                padding='same',
                                name='enc4_pool')(enc4)


# encoder, block 5
enc5 = layers.Conv2D(filters=32*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc5_conv1')(enc4_pool)

enc5 = layers.BatchNormalization()(enc5)
enc5 = layers.ReLU()(enc5)

enc5 = layers.Conv2D(filters=32*filter_multiplier,
                     kernel_size=enc_kernel,
                     strides=enc_strides,
                     padding='same',
                     name='enc5_conv2')(enc5)

enc5 = layers.BatchNormalization()(enc5)
enc5 = layers.ReLU()(enc5)
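
# at this point the four max-pooling layers have downsampled the 512x512
# input to a 32x32 bottleneck (512 -> 256 -> 128 -> 64 -> 32), with the
# largest filter depth of 32*8 = 256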

################## Decoder ##################

# decoder upconv parameters
dec_upconv_kernel = (2,2)
dec_upconv_stride = (2,2)

# decoder forward convolution parameters
dec_conv_stride = (1,1)
dec_conv_kernel = (3,3)

# Decoder, block 4
dec4_up = layers.Conv2DTranspose(filters=16*filter_multiplier,
                                 kernel_size=dec_upconv_kernel,
                                 strides=dec_upconv_stride,
                                 padding='same',
                                 name='dec4_upconv')(enc5)

dec4_conc = layers.concatenate([dec4_up,enc4],axis=-1)

dec4 = layers.Conv2D(filters=16*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec4_conv1')(dec4_conc)

dec4 = layers.BatchNormalization()(dec4)
dec4 = layers.ReLU()(dec4)

dec4 = layers.Conv2D(filters=16*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec4_conv2')(dec4)

dec4 = layers.BatchNormalization()(dec4)
dec4 = layers.ReLU()(dec4)


# Decoder, block 3
dec3_up = layers.Conv2DTranspose(filters=8*filter_multiplier,
                                 kernel_size=dec_upconv_kernel,
                                 strides=dec_upconv_stride,
                                 padding='same',
                                 name='dec3_upconv')(dec4)

dec3_conc = layers.concatenate([dec3_up,enc3],axis=-1)

dec3 = layers.Conv2D(filters=8*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec3_conv1')(dec3_conc)

dec3 = layers.BatchNormalization()(dec3)
dec3 = layers.ReLU()(dec3)

dec3 = layers.Conv2D(filters=8*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec3_conv2')(dec3)

dec3 = layers.BatchNormalization()(dec3)
dec3 = layers.ReLU()(dec3)


# Decoder, block 2
dec2_up = layers.Conv2DTranspose(filters=4*filter_multiplier,
                                 kernel_size=dec_upconv_kernel,
                                 strides=dec_upconv_stride,
                                 padding='same',
                                 name='dec2_upconv')(dec3)

dec2_conc = layers.concatenate([dec2_up,enc2],axis=-1)

dec2 = layers.Conv2D(filters=4*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec2_conv1')(dec2_conc)

dec2 = layers.BatchNormalization()(dec2)
dec2 = layers.ReLU()(dec2)

dec2 = layers.Conv2D(filters=4*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec2_conv2')(dec2)

dec2 = layers.BatchNormalization()(dec2)
dec2 = layers.ReLU()(dec2)


# Decoder, block 1
dec1_up = layers.Conv2DTranspose(filters=2*filter_multiplier,
                                 kernel_size=dec_upconv_kernel,
                                 strides=dec_upconv_stride,
                                 padding='same',
                                 name='dec1_upconv')(dec2)

dec1_conc = layers.concatenate([dec1_up,enc1],axis=-1)

dec1 = layers.Conv2D(filters=2*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec1_conv1')(dec1_conc)

dec1 = layers.BatchNormalization()(dec1)
dec1 = layers.ReLU()(dec1)

dec1 = layers.Conv2D(filters=2*filter_multiplier,
                     kernel_size=dec_conv_kernel,
                     strides=dec_conv_stride,
                     padding='same',
                     name='dec1_conv2')(dec1)

dec1 = layers.BatchNormalization()(dec1)
dec1 = layers.ReLU()(dec1)

# 1x1 convolution produces the 4-channel feature map, one channel per class
conv_seg = layers.Conv2D(filters=4,
                         kernel_size=(1,1),
                         name='conv_feature_map')(dec1)

# the softmax stays in float32 so the output remains stable if mixed
# precision is enabled
prob_dist = layers.Softmax(dtype='float32')(conv_seg)

unet = keras.Model(inputs=net_input,outputs=prob_dist,name='uNet')

unet.summary()
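
# the encoder and decoder stages above all repeat the same
# conv -> batch-norm -> ReLU pattern twice. A minimal helper sketch of that
# double-convolution block, for illustration only (the network above is
# already built layer by layer and does not use this function):
def double_conv_block(x, filters, kernel=(3,3), strides=(1,1), name=None):
    '''Applies Conv2D -> BatchNormalization -> ReLU twice.'''
    for idx in range(2):
        conv_name = None if name is None else '{}_conv{}'.format(name,idx+1)
        x = layers.Conv2D(filters=filters,
                          kernel_size=kernel,
                          strides=strides,
                          padding='same',
                          name=conv_name)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x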

# %% setting up training

cce = tf.keras.losses.CategoricalCrossentropy()

# running the network eagerly because it allows converting tensors to numpy
# arrays inside the weighted loss calculation
unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=weighted_cce_loss,
    run_eagerly=True,
    metrics=[tf.keras.metrics.Precision(name='precision'),
             tf.keras.metrics.Recall(name='recall')]
)
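
# note that run_eagerly disables graph compilation and slows training; if
# the loss were rewritten with pure tensor ops (see the vectorized sketch
# after weighted_cce_loss), run_eagerly could likely be dropped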

# %%
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_recall',
                                                 mode='max',
                                                 factor=0.8,
                                                 patience=3,
                                                 min_lr=0.00001,
                                                 verbose=True)

# save the best weights to the same file that is loaded for evaluation below
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('unet_seg_functional.h5',
                                                   save_best_only=True,
                                                   save_weights_only=True,
                                                   monitor='val_recall',
                                                   mode='max',
                                                   verbose=True)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=8,
                                                     monitor='val_recall',
                                                     mode='max',
                                                     restore_best_weights=True,
                                                     verbose=True)

num_steps = 150

history = unet.fit(training_dataset,
                   epochs=20,
                   steps_per_epoch=num_steps,
                   validation_data=validation_dataset,
                   callbacks=[checkpoint_cb,
                              early_stopping_cb,
                              reduce_lr])

# %%
# evaluate the network after loading the best weights saved above
unet.load_weights('./unet_seg_functional.h5')
results = unet.evaluate(testing_dataset)
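
# with the compile configuration above, results holds the testing-set
# [loss, precision, recall]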

# %%
# extracting loss vs epoch
loss = history.history['loss']
val_loss = history.history['val_loss']
# extracting precision vs epoch
precision = history.history['precision']
val_precision = history.history['val_precision']
# extracting recall vs epoch
recall = history.history['recall']
val_recall = history.history['val_recall']

epochs = range(len(loss))

figs, axes = plt.subplots(3,1)

# plotting loss and validation loss
axes[0].plot(epochs,loss)
axes[0].plot(epochs,val_loss)
axes[0].legend(['loss','val_loss'])
axes[0].set(xlabel='epochs',ylabel='crossentropy loss')

# plotting precision and validation precision
axes[1].plot(epochs,precision)
axes[1].plot(epochs,val_precision)
axes[1].legend(['precision','val_precision'])
axes[1].set(xlabel='epochs',ylabel='precision')

# plotting recall and validation recall
axes[2].plot(epochs,recall)
axes[2].plot(epochs,val_recall)
axes[2].legend(['recall','val_recall'])
axes[2].set(xlabel='epochs',ylabel='recall')


# %% exploring the predictions to better understand what the network is doing

images = []
gt = []
predictions = []

# take the next 10 samples from the testing dataset and iterate through them
for sample in testing_dataset.take(10):
    # make sure the sample has the correct dimensions
    print(sample[0].shape)
    # convert the image back to RGB and store it in the list
    image = sample[0]
    image = cv.cvtColor(np.squeeze(np.asarray(image).copy()),cv.COLOR_BGR2RGB)
    images.append(image)
    # extract the ground truth and store it in the list
    ground_truth = sample[1]
    gt.append(ground_truth)
    # perform inference
    out = unet.predict(sample[0])
    predictions.append(out)
    # show the original input image
    plt.imshow(image)
    plt.show()
    # flatten the ground truth from its one-hot encoding along the last
    # axis, and show the resulting image
    squeezed_gt = tf.argmax(ground_truth,axis=-1)
    squeezed_prediction = tf.argmax(out,axis=-1)
    plt.imshow(squeezed_gt[0,:,:])
    # print the unique classes present in this tile
    print(np.unique(squeezed_gt))
    plt.show()
    # show the flattened predictions
    plt.imshow(squeezed_prediction[0,:,:])
    print(np.unique(squeezed_prediction))
    plt.show()
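
# as a rough quantitative check on the final sample above, a per-class
# intersection-over-union sketch (an illustration, not part of the original
# exploration):
conf = tf.math.confusion_matrix(tf.reshape(squeezed_gt,[-1]),
                                tf.reshape(squeezed_prediction,[-1]),
                                num_classes=4)
conf = tf.cast(conf,tf.float64)
intersection = tf.linalg.diag_part(conf)
union = tf.reduce_sum(conf,axis=0) + tf.reduce_sum(conf,axis=1) - intersection
print('per-class IoU:',(intersection/(union + 1e-9)).numpy())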

# %%
# select one of the images cycled through above to investigate further
image_to_investigate = 2

# show the original image
plt.imshow(images[image_to_investigate])
plt.show()

# show the ground truth for this tile
squeezed_gt = tf.argmax(gt[image_to_investigate],axis=-1)
plt.imshow(squeezed_gt[0,:,:])
# print the unique classes present in the ground truth
print(np.unique(squeezed_gt))
plt.show()

# show the predicted probability map for the last class (vasculature)
squeezed_prediction = tf.argmax(predictions[image_to_investigate],axis=-1)
plt.imshow(predictions[image_to_investigate][0,:,:,3])
plt.show()

# flatten the prediction from its one-hot encoding and show the resulting
# image
plt.imshow(squeezed_prediction[0,:,:])
print(np.unique(squeezed_prediction))
plt.show()