# uNet_FullImage_Segmentation_SCCE.py
# %% importing packages

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
from skimage import measure
from skimage import morphology
from scipy import ndimage
import cv2 as cv
import os
import matplotlib.pyplot as plt
import tqdm
from natsort import natsorted
plt.rcParams['figure.figsize'] = [50, 150]


# %% Citations
#############################################################
#############################################################


# Defining Functions
#############################################################
#############################################################

#############################################################

class EncoderBlock(layers.Layer):
    '''This layer implements an encoder block with two convolutional layers
       and an option for returning both a max-pooled output (with a stride and
       pool size of (2,2)) and the output of the second convolution for the
       skip connections implemented later in the network during the decoding
       section. All padding is set to "same" for cleanliness.

       When initializing, it receives the number of filters to be used in both
       of the convolutional layers as well as the kernel size and stride for
       those same layers. It also receives the trainable variable for use with
       the batch normalization layers.'''

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 strides=(1,1),
                 trainable=True,
                 name='encoder_block',
                 **kwargs):

        super(EncoderBlock,self).__init__(trainable=trainable,name=name,**kwargs)
        # When initializing, this object receives a trainable parameter for
        # freezing the convolutional layers.

        # including the image normalization within the network for easier image
        # processing during inference
        self.image_normalization = layers.Rescaling(scale=1./255)

        # below creates the first of two convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv1',
                      trainable=trainable)

        # second of two convolutional layers
        self.conv2 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv2',
                      trainable=trainable)

        # creates the max-pooling layer for downsampling the image.
        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
                                    strides=(2,2),
                                    padding='same',
                                    name='enc_pool')

        # ReLU layer for activations.
        self.ReLU = layers.ReLU()

        # both batch normalization layers for use with their corresponding
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self,input,normalization=False,training=True,include_pool=True):

        # first conv of the encoder block, optionally preceded by rescaling
        # the raw image values into [0,1]
        if normalization:
            x = self.image_normalization(input)
            x = self.conv1(x)
        else:
            x = self.conv1(input)

        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second conv of the encoder block
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # calculate and include the max pooling layer if include_pool is True.
        # The un-pooled output is used for the skip connections later in the
        # network.
        if include_pool:
            pooled_x = self.enc_pool(x)
            return(x,pooled_x)

        else:
            return(x)

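# %% An illustrative sketch (not in the original script): a quick shape check
# of EncoderBlock. With include_pool=True (the default) it returns both the
# full-resolution features for the skip connection and the 2x-downsampled
# output. The 32x32 input size here is arbitrary.
_enc_demo = EncoderBlock(filters=8,name='enc_demo')
_feats,_pooled = _enc_demo(input=tf.zeros((1,32,32,3)),training=False)
print(_feats.shape,_pooled.shape)  # expected: (1, 32, 32, 8) (1, 16, 16, 8)
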
#############################################################

class DecoderBlock(layers.Layer):
    '''This layer implements a decoder block that, when called, receives both
       an input and a "skip connection". The input is passed to the
       "up convolution" or transpose conv layer to double the dimensions before
       being concatenated with its associated skip connection from the encoder
       section of the network. All padding is set to "same" for cleanliness.
       The decoder block also has an option for including an additional
       "segmentation" layer, which is a (1,1) convolution with 6 filters that
       produces the logits for the one-hot encoded ground truth.

       When initializing, it receives the number of filters to be used in the
       up convolutional layer as well as the other two forward convolutions.
       The received kernel_size and strides are used for the forward
       convolutions, with the up convolution kernel and stride set to (2,2).'''
    def __init__(self,
                 filters,
                 trainable=True,
                 kernel_size=(3,3),
                 strides=(1,1),
                 name='DecoderBlock',
                 **kwargs):

        super(DecoderBlock,self).__init__(trainable=trainable,name=name,**kwargs)

        # creating the up convolution layer
        self.up_conv = layers.Conv2DTranspose(filters=filters,
                                              kernel_size=(2,2),
                                              strides=(2,2),
                                              padding='same',
                                              name='decoder_upconv',
                                              trainable=trainable)

        # the first of two forward convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name='decoder_conv1',
                                   trainable=trainable)

        # second convolutional layer
        self.conv2 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name='decoder_conv2',
                                   trainable=trainable)

        # this creates the output prediction logits layer, one filter per class.
        self.seg_out = layers.Conv2D(filters=6,
                        kernel_size=(1,1),
                        name='conv_feature_map')

        # ReLU for activation of all above layers
        self.ReLU = layers.ReLU()

        # the individual batch normalization layers for their respective
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()


    def call(self,input,skip_conn,training=True,segmentation=False,prob_dist=True):

        up = self.up_conv(input) # perform image up convolution
        # concatenate the input and the skip_conn along the features axis
        concatenated = layers.concatenate([up,skip_conn],axis=-1)

        # first convolution
        x = self.conv1(concatenated)
        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second convolution
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # if segmentation is True, then run the segmentation (1,1) convolution
        # and use the Softmax to produce a probability distribution.
        if segmentation:
            seg = self.seg_out(x)
            # deliberately set as "float32" to ensure proper calculation if
            # switching to mixed precision for efficiency
            if prob_dist:
                seg = layers.Softmax(dtype='float32')(seg)

            return(seg)

        else:
            return(x)

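# %% An illustrative sketch (not in the original script): DecoderBlock doubles
# the spatial dimensions of its input via the transpose convolution before
# concatenating with the skip connection, so the input here is 16x16 and the
# skip connection 32x32. With segmentation=True it instead returns the 6-class
# output (softmaxed when prob_dist=True). All sizes here are arbitrary.
_dec_demo = DecoderBlock(filters=8,name='dec_demo')
_dec_feats = _dec_demo(input=tf.zeros((1,16,16,8)),
                       skip_conn=tf.zeros((1,32,32,8)),
                       training=False)
print(_dec_feats.shape)  # expected: (1, 32, 32, 8)
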
#############################################################

class uNet(keras.Model):
    '''This is a sub-classed model that uses the encoder and decoder blocks
       defined above to create a custom U-Net. The differences from the
       original paper include a variable filter scalar (filter_multiplier),
       batch normalization between each convolutional layer and its associated
       ReLU activation, as well as feature normalization implemented in the
       first layer of the network.'''
    def __init__(self,filter_multiplier=2,**kwargs):
        super(uNet,self).__init__()

        # Defining encoder blocks
        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
                                           name='Enc1')
        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
                                           name='Enc2')
        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
                                           name='Enc3')
        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
                                           name='Enc4')
        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
                                           name='Enc5')

        # Defining decoder blocks. The names are in reverse order to make it
        # (hopefully) easier to understand which skip connections are
        # associated with which decoder layers.
        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
                                           name='Dec4')
        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
                                           name='Dec3')
        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
                                           name='Dec2')
        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
                                           name='Dec1')


    def call(self,inputs,training=False,predict=False,threshold=3):

        # encoder
        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
        enc5 = self.encoder_block5(input=enc4_pool,
                                   include_pool=False,
                                   training=training)

        # enc4 = self.encoder_block4(input=enc3_pool,
        #                            include_pool=False,
        #                            training=training)


        # decoder
        dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
        prob_dist_out = self.decoder_block1(input=dec2,
                                            skip_conn=enc1,
                                            segmentation=True,
                                            training=training)

        # The predict branch below allows one to set a threshold for the
        # uncertainty, an arbitrary value that corresponds to the maximum
        # value of the logits predicted at a specific point in the image. It
        # only keeps predictions for the vascular and neural tissues if they
        # are above the confidence threshold; below the threshold, the
        # predictions default to muscle, connective tissue, or background.
        if predict:
            # run the final decoder block a second time without the Softmax
            # to recover the raw segmentation logits.
            seg_logits_out = self.decoder_block1(input=dec2,
                                                 skip_conn=enc1,
                                                 segmentation=True,
                                                 training=training,
                                                 prob_dist=False)

            # rename the value for consistency and write protection.
            y_pred = seg_logits_out
            pred_shape = (1,1024,1024,6)
            # Getting an image-sized preliminary segmentation prediction
            squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))

            # initializing the variable used for storing the maximum logits
            # at each pixel location.
            max_value_predictions = tf.zeros((1024,1024))

            # cycle through all the classes
            for idx in range(6):

                # current class logits
                current_slice = tf.squeeze(y_pred[:,:,:,idx])
                # find the locations where this class is predicted
                current_indices = squeezed_prediction == idx
                # define the shape so that this function can run in graph mode
                # and not need eager execution.
                current_indices.set_shape((1024,1024))
                # Get the indices of where the idx class is predicted
                indices = tf.where(squeezed_prediction == idx)
                # get the output of boolean_mask to enable scatter update of
                # the tensor. This is required because tensors do not support
                # mask indexing.
                values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.float32)
                # Place the maximum logit values at each point in an
                # image-sized matrix, indicating the confidence in the
                # prediction at each pixel.
                max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates)

            # zero the logits of the vascular (3) and neural (4) classes
            # wherever their confidence falls below the threshold.
            for idx in [3,4]:
                # the neural class uses a threshold two units lower than the
                # vascular class.
                class_threshold = threshold - 2 if idx == 4 else threshold

                mask_list = []
                for idx2 in range(6):
                    if idx2 == idx:
                        mid_mask = max_value_predictions < class_threshold
                        mask_list.append(mid_mask.astype(tf.float32))
                    else:
                        mask_list.append(tf.zeros((1024,1024)))

                mask = tf.cast(tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0),tf.bool)

                indexes = tf.where(mask)
                values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask)

                seg_logits_out = tf.tensor_scatter_nd_update(seg_logits_out,indexes,values_updates)
                prob_dist_out = layers.Softmax(dtype='float32')(seg_logits_out)
            # print("updated logits!")

        return(prob_dist_out)

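# %% An illustrative sketch (not in the original script) of the scatter-update
# idiom used in uNet.call above: zero out a class's logits wherever its
# winning logit falls below a confidence threshold, so the subsequent
# Softmax/argmax falls back to the remaining classes. Here a toy 2x2 map with
# 3 classes; the thresholded class index (2) and threshold (2.0) are arbitrary.
_logits = tf.constant([[[[0.2, 2.5, 0.1],
                         [0.1, 0.3, 1.0]],
                        [[3.0, 0.2, 0.1],
                         [0.2, 0.1, 0.4]]]])                # shape (1,2,2,3)
_winners = tf.argmax(_logits,axis=-1)                        # predicted class per pixel
_confidence = tf.reduce_max(_logits,axis=-1)                 # winning logit per pixel
_low_conf = tf.where((_winners == 2) & (_confidence < 2.0))  # (k,3) pixel indices
# append the class channel index to address the class-2 logit of each pixel
_class_col = tf.ones_like(_low_conf[:,:1])*2
_update_idx = tf.concat([_low_conf,_class_col],axis=1)       # (k,4) indices
_thresholded = tf.tensor_scatter_nd_update(
    _logits,_update_idx,tf.zeros([tf.shape(_low_conf)[0]]))
print(tf.argmax(_thresholded,axis=-1))  # low-confidence class-2 pixels reassigned
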
#############################################################

def get_image_blocks(image,tile_distance=512,tile_size=1024):
    '''Receives an image as well as a minimum distance between tile centers.
       Returns the image dimensions and a list of tile centers evenly
       distributed across the image, keeping every tile fully inside the
       image bounds.'''
    image_dimensions = image.shape

    # mask of the locations far enough from the borders to hold a full tile
    safe_mask = np.zeros([image_dimensions[0],image_dimensions[1]])
    safe_mask[int(tile_size/2):image_dimensions[0]-int(tile_size/2),
              int(tile_size/2):image_dimensions[1]-int(tile_size/2)] = 1

    # candidate grid of tile centers, spaced tile_distance apart
    grid_0 = np.arange(0,image_dimensions[0],tile_distance)
    grid_1 = np.arange(0,image_dimensions[1],tile_distance)

    center_indexes = []

    # keep only the grid points whose tiles fit entirely within the image
    for grid0 in grid_0:
        for grid1 in grid_1:
            if safe_mask[grid0,grid1]:
                center_indexes.append([grid0,grid1])

    return([image_dimensions,center_indexes])
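
# %% An illustrative sketch (not in the original script): on a synthetic
# 3000x4000 image with the defaults used below, the "safe" region keeps every
# center at least tile_size/2 = 512 px from the borders, and tile_distance=512
# spaces the centers at half-tile overlap.
_demo_dims,_demo_centers = get_image_blocks(np.zeros((3000,4000,3),dtype=np.uint8),
                                            tile_distance=512,
                                            tile_size=1024)
print(_demo_dims,len(_demo_centers),_demo_centers[0])  # 24 centers; first at [512, 512]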

#############################################################

def get_reduced_tile_indexes(tile_center,returned_size=1024):
    '''Returns the [start, end] row and column bounds of a square window of
       size returned_size centered on tile_center.'''
    start_0 = int(tile_center[0] - returned_size/2)
    end_0 = int(tile_center[0] + returned_size/2)

    start_1 = int(tile_center[1] - returned_size/2)
    end_1 = int(tile_center[1] + returned_size/2)

    return([start_0,end_0],[start_1,end_1])
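
# %% An illustrative sketch (not in the original script): the tile bounds are
# just a square window centered on a point, e.g. a 1024 window on (1500, 2000):
print(get_reduced_tile_indexes((1500,2000),returned_size=1024))
# expected: ([988, 2012], [1488, 2512])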

#############################################################

def segment_tiles(unet,center_indexes,image,threshold=3,scaling_factor=1,tile_size=1024):
    '''Runs the network on a tile centered at each entry of center_indexes and
       stitches the central 512x512 crop of each prediction into a full,
       image-sized segmentation.'''
    m,n,z = image.shape
    segmentation = np.zeros((m,n))

    for idx in tqdm.tqdm(range(len(center_indexes))):
        center = center_indexes[idx]
        dim0, dim1 = get_reduced_tile_indexes(center,tile_size)
        sub_sectioned_tile = image[dim0[0]:dim0[1],dim1[0]:dim1[1]]

        full_tile_dim0,full_tile_dim1,z = sub_sectioned_tile.shape

        color_tile = sub_sectioned_tile[:,:,0:3]

        if scaling_factor > 1:
            height = color_tile.shape[0]
            width = color_tile.shape[1]

            height2 = int(height/scaling_factor)
            width2 = int(width/scaling_factor)

            # cv.resize expects dsize as (width, height) and the interpolation
            # flag as a keyword argument
            color_tile = cv.resize(color_tile,(width2,height2),interpolation=cv.INTER_AREA)

        color_tile = color_tile[None,:,:,:]

        prediction = unet(color_tile,training=False,predict=True,threshold=threshold)

        prediction_tile = np.squeeze(np.asarray(tf.argmax(prediction,axis=-1)).astype(np.float32).copy())

        if scaling_factor > 1:
            prediction_tile = cv.resize(prediction_tile,(full_tile_dim1,full_tile_dim0),interpolation=cv.INTER_NEAREST)


        dim0, dim1 = get_reduced_tile_indexes(center,returned_size=512)

        # fix this hard coding of the tile indexes for the prediction;
        # [256:768] is the central 512x512 crop of a 1024x1024 tile.
        segmentation[dim0[0]:dim0[1],dim1[0]:dim1[1]] = prediction_tile[256:768,256:768]

    return(segmentation)
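
# %% An illustrative sketch (not in the original script) of why the central
# 512x512 crop is stitched: with centers spaced tile_distance=512 apart, the
# [256:768] windows of adjacent 1024 tiles abut exactly, so the stitched
# segmentation has no gaps or doubly-written pixels.
print(get_reduced_tile_indexes([1024,1024],returned_size=512)[1])  # [768, 1280]
print(get_reduced_tile_indexes([1024,1536],returned_size=512)[1])  # [1280, 1792]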

#############################################################

def double_check_produced_dataset(new_directory,image_idx=0):
    '''This function samples an image from a given directory of produced
       segmentations (a random one unless a nonzero image_idx is provided)
       and displays it to verify the segmentation files were written
       correctly.'''
    os.chdir(new_directory)
    file_names = tf.io.gfile.glob('./*.png')
    file_names = natsorted(file_names)
    # pick a random image index number if one wasn't provided
    if image_idx == 0:
        image_idx = int(np.random.random()*len(file_names))
    else:
        pass

    print(image_idx)
    # reading the specific file at the chosen index
    segmentation = cv.imread(file_names[image_idx],cv.IMREAD_UNCHANGED)
    print(file_names[image_idx])
    # plotting the segmentation with a fixed class value range
    plt.imshow(segmentation,vmin=0, vmax=6)
    print(np.unique(segmentation))
    plt.show()

#############################################################
#############################################################
# %%
full_image_directory = '/var/confocaldata/HumanNodal/HeartData/16/02/JPG/'
file_names = tf.io.gfile.glob(full_image_directory + '*.jpg')
file_names = natsorted(file_names)
# %%
tile_size = 1024
unet_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_6'
os.chdir(unet_directory)
# run a forward pass on dummy data to build the network variables before
# loading the trained weights
sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
unet = uNet(filter_multiplier=12)
out = unet(sample_data,training=False,predict=True,threshold=3)
unet.summary()
unet.load_weights('/var/confocaldata/HumanNodal/HeartData/Best Networks/unet_seg_weights.63-0.9172-0.0065.h5')

# %%

image = cv.imread(file_names[250],cv.IMREAD_UNCHANGED)
# replicate-pad the borders so tiles centered near the image edges stay in bounds
image = cv.copyMakeBorder(image,2000,2000,2000,2000,cv.BORDER_REPLICATE)

# %%
dimensions,center_indexes = get_image_blocks(image,
                                             tile_distance=512,
                                             tile_size=tile_size
                                             )

segmentation = segment_tiles(unet,
                             center_indexes,
                             image,
                             threshold=3,
                             scaling_factor=1,
                             tile_size=tile_size)

# %%
corrected_image = cv.cvtColor(image,cv.COLOR_BGR2RGB)
plt.imshow(corrected_image[:,:,0:3])
plt.show()

# %%
# plt.imshow(image[:,:,3])
# plt.show()
plt.imshow(segmentation)
plt.show()

# %%

double_check_produced_dataset('/var/confocaldata/HumanNodal/HeartData/10/02/uNet_Segmentations',
                              image_idx=0)
# %%