Switch to unified view

a b/uNet_FullImage_Segmentation.py
1
# %% importing packages
2
3
import numpy as np
4
import tensorflow as tf
5
from tensorflow import keras
6
from tensorflow.keras import layers
7
from tensorflow.keras import mixed_precision
8
from tensorflow.python.ops.numpy_ops import np_config
9
np_config.enable_numpy_behavior()
10
from skimage import measure
11
from skimage import morphology
12
from scipy import ndimage
13
import cv2 as cv
14
import os
15
import matplotlib.pyplot as plt
16
import tqdm
17
from natsort import natsorted
18
plt.rcParams['figure.figsize'] = [50, 150]
19
20
21
# %% Citations
22
#############################################################
23
#############################################################
24
25
26
# Defining Functions
27
#############################################################
28
#############################################################
29
30
class EncoderBlock(layers.Layer):
31
    '''This function returns an encoder block with two convolutional layers and 
32
       an option for returning both a max-pooled output with a stride and pool 
33
       size of (2,2) and the output of the second convolution for skip 
34
       connections implemented later in the network during the decoding 
35
       section. All padding is set to "same" for cleanliness.
36
       
37
       When initializing it receives the number of filters to be used in both
38
       of the convolutional layers as well as the kernel size and stride for 
39
       those same layers. It also receives the trainable variable for use with
40
       the batch normalization layers.'''
41
42
    def __init__(self,
43
                 filters,
44
                 kernel_size=(3,3),
45
                 strides=(1,1),
46
                 trainable=True,
47
                 name='encoder_block',
48
                 **kwargs):
49
50
        super(EncoderBlock,self).__init__(trainable, name, **kwargs)
51
        # When initializing this object receives a trainable parameter for
52
        # freezing the convolutional layers. 
53
54
        # including the image normalization within the network for easier image
55
        # processing during inference
56
        self.image_normalization = layers.Normalization()
57
58
        # below creates the first of two convolutional layers
59
        self.conv1 = layers.Conv2D(filters=filters,
60
                      kernel_size=kernel_size,
61
                      strides=strides,
62
                      padding='same',
63
                      name='encoder_conv1',
64
                      trainable=trainable)
65
66
        # second of two convolutional layers
67
        self.conv2 = layers.Conv2D(filters=filters,
68
                      kernel_size=kernel_size,
69
                      strides=strides,
70
                      padding='same',
71
                      name='encoder_conv2',
72
                      trainable=trainable)
73
74
        # creates the max-pooling layer for downsampling the image.
75
        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
76
                                    strides=(2,2),
77
                                    padding='same',
78
                                    name='enc_pool')
79
80
        # ReLU layer for activations.
81
        self.ReLU = layers.ReLU()
82
        
83
        # both batch normalization layers for use with their corresponding
84
        # convolutional layers.
85
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
86
        self.batch_norm2 = tf.keras.layers.BatchNormalization()
87
88
    def call(self,input,training=True,include_pool=True):
89
        
90
        # first conv of the encoder block
91
        x = self.image_normalization(input)
92
        x = self.conv1(x)
93
        x = self.batch_norm1(x,training=training)
94
        x = self.ReLU(x)
95
96
        # second conv of the encoder block
97
        x = self.conv2(x)
98
        x = self.batch_norm2(x,training=training)
99
        x = self.ReLU(x)
100
        
101
        # calculate and include the max pooling layer if include_pool is true.
102
        # This output is used for the skip connections later in the network.
103
        if include_pool:
104
            pooled_x = self.enc_pool(x)
105
            return(x,pooled_x)
106
107
        else:
108
            return(x)
109
110
111
#############################################################
112
113
class DecoderBlock(layers.Layer):
114
    '''This function returns a decoder block that when called receives both an
115
       input and a "skip connection". The input is passed to the 
116
       "up convolution" or transpose conv layer to double the dimensions before
117
       being concatenated with its associated skip connection from the encoder
118
       section of the network. All padding is set to "same" for cleanliness. 
119
       The decoder block also has an option for including an additional 
120
       "segmentation" layer, which is a (1,1) convolution with 4 filters, which
121
       produces the logits for the one-hot encoded ground truth. 
122
       
123
       When initializing it receives the number of filters to be used in the
124
       up convolutional layer as well as the other two forward convolutions. 
125
       The received kernel_size and stride is used for the forward convolutions,
126
       with the up convolution kernel and stride set to be (2,2).'''
127
    def __init__(self,
128
                 filters,
129
                 trainable=True,
130
                 kernel_size=(3,3),
131
                 strides=(1,1),
132
                 name='DecoderBlock',
133
                 **kwargs):
134
135
        super(DecoderBlock,self).__init__(trainable, name, **kwargs)
136
137
        # creating the up convolution layer
138
        self.up_conv = layers.Conv2DTranspose(filters=filters,
139
                                              kernel_size=(2,2),
140
                                              strides=(2,2),
141
                                              padding='same',
142
                                              name='decoder_upconv',
143
                                              trainable=trainable)
144
145
        # the first of two forward convolutional layers
146
        self.conv1 = layers.Conv2D(filters=filters,
147
                                   kernel_size=kernel_size,
148
                                   strides=strides,
149
                                   padding='same',
150
                                   name ='decoder_conv1',
151
                                   trainable=trainable)
152
153
        # second convolutional layer
154
        self.conv2 = layers.Conv2D(filters=filters,
155
                                   kernel_size=kernel_size,
156
                                   strides=strides,
157
                                   padding='same',
158
                                   name ='decoder_conv2',
159
                                   trainable=trainable)
160
161
        # this creates the output prediction logits layer.
162
        self.seg_out = layers.Conv2D(filters=7,
163
                        kernel_size=(1,1),
164
                        name='conv_feature_map')
165
166
        # ReLU for activation of all above layers
167
        self.ReLU = layers.ReLU()
168
        
169
        # the individual batch normalization layers for their respective 
170
        # convolutional layers.
171
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
172
        self.batch_norm2 = tf.keras.layers.BatchNormalization()
173
174
175
    def call(self,input,skip_conn,training=True,segmentation=False):
176
        
177
        up = self.up_conv(input) # perform image up convolution
178
        # concatenate the input and the skip_conn along the features axis
179
        concatenated = layers.concatenate([up,skip_conn],axis=-1)
180
181
        # first convolution 
182
        x = self.conv1(concatenated)
183
        x = self.batch_norm1(x,training=training)
184
        x = self.ReLU(x)
185
186
        # second convolution
187
        x = self.conv2(x)
188
        x = self.batch_norm2(x,training=training)
189
        x = self.ReLU(x)
190
191
        # if segmentation is True, then run the segmentation (1,1) convolution
192
        # and use the Softmax to produce a probability distribution.
193
        if segmentation:
194
            seg = self.seg_out(x)
195
            # deliberately set as "float32" to ensure proper calculation if 
196
            # switching to mixed precision for efficiency
197
            prob = layers.Softmax(dtype='float32')(seg)
198
            return(prob)
199
200
        else:
201
            return(x)
202
203
204
#############################################################
205
206
class uNet(keras.Model):
207
    '''This is a sub-classed model that uses the encoder and decoder blocks
208
       defined above to create a custom unet. The differences from the original 
209
       paper include a variable filter scalar (filter_multiplier), batch 
210
       normalization between each convolutional layer and the associated ReLU 
211
       activation, as well as feature normalization implemented in the first 
212
       layer of the network.'''
213
    def __init__(self,filter_multiplier=2,**kwargs):
214
        super(uNet,self).__init__()
215
        
216
        # Defining encoder blocks
217
        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
218
                                           name='Enc1')
219
        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
220
                                           name='Enc2')
221
        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
222
                                           name='Enc3')
223
        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
224
                                           name='Enc4')
225
        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
226
                                           name='Enc5')
227
228
        # Defining decoder blocks. The names are in reverse order to make it 
229
        # (hopefully) easier to understand which skip connections are associated
230
        # with which decoder layers.
231
        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
232
                                           name='Dec4')
233
        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
234
                                           name='Dec3')
235
        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
236
                                           name='Dec2')
237
        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
238
                                           name='Dec1')
239
240
241
    def call(self,inputs,training):
242
243
        # encoder    
244
        enc1,enc1_pool = self.encoder_block1(input=inputs,training=training)
245
        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
246
        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
247
        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
248
        enc5 = self.encoder_block5(input=enc4_pool,
249
                                   include_pool=False,
250
                                   training=training)
251
252
        # decoder
253
        dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
254
        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
255
        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
256
        seg_logits_out = self.decoder_block1(input=dec2,
257
                                             skip_conn=enc1,
258
                                             segmentation=True,
259
                                             training=training)
260
261
        return(seg_logits_out)
262
263
#############################################################
264
265
def get_image_blocks(image,tile_distance=512,tile_size=1024):
266
    '''Receives an image as well as a minimum distance between tiles. 
267
       Returns the name of the image processed, the image dimensions, and a list
268
       of tile centers evenly distributed across the tissue surface.'''
269
270
    tissue_outline = image[:,:,3] != 0
271
    tissue_outline = ndimage.binary_fill_holes(tissue_outline)
272
    image_dimensions = tissue_outline.shape
273
274
    safe_mask = np.zeros(image_dimensions)
275
    safe_mask[int(tile_size/2):image_dimensions[0]-int(tile_size/2),
276
              int(tile_size/2):image_dimensions[1]-int(tile_size/2)] = 1
277
278
    grid_0 = np.arange(0,image_dimensions[0],tile_distance)
279
    grid_1 = np.arange(0,image_dimensions[1],tile_distance)
280
281
    
282
283
    center_indexes = []
284
285
    for grid0 in grid_0:
286
        for grid1 in grid_1:
287
            if safe_mask[grid0,grid1]:
288
                center_indexes.append([grid0,grid1])
289
                
290
    # for y,x, in center_indexes:
291
    #     plt.plot(x,y,marker='o',color='red',markersize=25)
292
    # plt.imshow(tissue_outline)
293
294
    # plt.show()
295
296
    return([image_dimensions,center_indexes])
297
298
#############################################################
299
300
def get_reduced_tile_indexes(tile_center,returned_size=512):
301
    start_0 = int(tile_center[0] - returned_size/2)
302
    end_0 = int(tile_center[0] + returned_size/2)
303
304
    start_1 = int(tile_center[1] - returned_size/2)
305
    end_1 = int(tile_center[1] + returned_size/2)
306
307
    return([start_0,end_0],[start_1,end_1])
308
309
#############################################################
310
311
def segment_tiles(unet,center_indexes,image,scaling_factor=1,tile_size=4096):
312
    
313
    m,n,z = image.shape
314
    segmentation = np.zeros((m,n))
315
316
    for idx in tqdm.tqdm(range(len(center_indexes))):
317
        center = center_indexes[idx]
318
        dim0, dim1 = get_reduced_tile_indexes(center,tile_size)
319
        sub_sectioned_tile = image[dim0[0]:dim0[1],dim1[0]:dim1[1]] 
320
321
        full_tile_dim0,full_tile_dim1,z = sub_sectioned_tile.shape
322
323
        color_tile = sub_sectioned_tile[:,:,0:3]
324
        seg_tile = sub_sectioned_tile[:,:,3]
325
326
        if scaling_factor > 1:
327
            height = color_tile.shape[0]
328
            width = color_tile.shape[1]
329
330
            height2 = int(height/scaling_factor)
331
            width2 = int(width/scaling_factor)
332
            
333
            color_tile = cv.resize(color_tile,[height2,width2],cv.INTER_AREA)
334
335
        if scaling_factor > 1:
336
            height = seg_tile.shape[0]
337
            width = seg_tile.shape[1]
338
339
            height2 = int(height/scaling_factor)
340
            width2 = int(width/scaling_factor)
341
            
342
            seg_tile = cv.resize(seg_tile,[height2,width2],cv.INTER_LINEAR)
343
344
        color_tile = color_tile[None,:,:,:]
345
346
        prediction = unet.predict(color_tile,verbose=0)
347
348
        prediction_tile = np.squeeze(np.asarray(tf.argmax(prediction,axis=-1)).astype(np.float32).copy())
349
350
        if scaling_factor > 1:
351
            prediction_tile = cv.resize(prediction_tile,[full_tile_dim0,full_tile_dim1],cv.INTER_LINEAR)
352
353
354
        dim0, dim1 = get_reduced_tile_indexes(center,returned_size=512)
355
356
        # fix this hard coding of the tile indexes for the prediction
357
        segmentation[dim0[0]:dim0[1],dim1[0]:dim1[1]] = prediction_tile[256:768,256:768]
358
359
    return(segmentation)
360
361
#############################################################
362
363
def double_check_produced_dataset(new_directory,image_idx=0):
364
    '''this function samples a random image from a given directory, crops off 
365
       the ground truth from the 4th layer, and displays the color image to 
366
       verify they work.'''
367
    os.chdir(new_directory)
368
    file_names = tf.io.gfile.glob('./*.png')
369
    file_names = natsorted(file_names)
370
    # pick a random image index number
371
    if image_idx == 0:
372
        image_idx = int(np.random.random()*len(file_names))
373
    else:
374
        pass
375
376
    print(image_idx)
377
    # reading specific file from the random index
378
    segmentation = cv.imread(file_names[image_idx],cv.IMREAD_UNCHANGED)
379
    # changing the color for the tile from BGR to RGB
380
    print(file_names[image_idx])
381
    # plotting the images next to each other
382
    plt.imshow(segmentation,vmin=0, vmax=6)
383
    print(np.unique(segmentation))
384
    plt.show()
385
386
#############################################################
387
#############################################################
388
# %%
389
full_image_directory = '/media/briancottle/Samsung_T5/ML_Dataset_5/'
390
file_names = tf.io.gfile.glob(full_image_directory + '*.png')
391
392
# %%
393
tile_size = 1024
394
unet_directory =  '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_4'
395
os.chdir(unet_directory)
396
sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
397
unet = uNet(filter_multiplier=12)
398
out = unet(sample_data)
399
unet.summary()
400
unet.load_weights('./unet_seg_weights.49-0.52-0.94-0.92.h5')
401
402
# %%
403
404
image = cv.imread(file_names[420],cv.IMREAD_UNCHANGED)
405
image = cv.copyMakeBorder(image,2000,2000,2000,2000,cv.BORDER_REPLICATE)
406
407
# %%
408
dimensions,center_indexes = get_image_blocks(image,
409
                                             tile_distance=512,
410
                                             tile_size=tile_size
411
                                             )
412
413
segmentation = segment_tiles(unet,
414
                             center_indexes,
415
                             image,
416
                             scaling_factor=1,
417
                             tile_size=tile_size)
418
419
# %%
420
corrected_image = cv.cvtColor(image,cv.COLOR_BGR2RGB)
421
plt.imshow(corrected_image[:,:,0:3])
422
plt.show()
423
plt.imshow(image[:,:,3])
424
plt.show()
425
plt.imshow(segmentation)
426
plt.show()
427
428
# %%
429
430
double_check_produced_dataset('/var/confocaldata/HumanNodal/HeartData/10/02/uNet_Segmentations',
431
                              image_idx=0)
432
# %%