Diff of /directory_segmentation.py [000000] .. [3b7fea]


#!/usr/bin/env python3
"""
Author : briancottle
Date   : 2022-12-14
Purpose: Segment an entire directory of histological .JPG files using the
user provided uNet weights generated and saved by the uNet_Subclassed_SCCE.py
file.
"""

import argparse
from typing import NamedTuple, TextIO
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
from skimage import measure
from skimage import morphology
from scipy import ndimage
import cv2 as cv
import os
import matplotlib.pyplot as plt
import tqdm
from natsort import natsorted

# --------------------------------------------------

class Args(NamedTuple):
    """ Command-line arguments """
    uNet_weights: str
    jpg_directory: str
    heart_id: str
    GPU_id: str
    threshold: int


# --------------------------------------------------

def get_args() -> Args:
    """ Get command-line arguments """

    parser = argparse.ArgumentParser(
        description='Providing command line arguments.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('-u',
                        '--uNet_weights',
                        type=str,
                        metavar='uNet',
                        help='Path to an .h5 file created using the '
                        'architecture established in the '
                        'uNet_Subclassed_SCCE.py file.')

    parser.add_argument('-d',
                        '--jpg_directory',
                        help='The directory containing the .jpg files for '
                        'segmentation',
                        metavar='Dir',
                        type=str,
                        default='')

    parser.add_argument('-i',
                        '--heart_id',
                        help='The two digit heart ID for the data being '
                        'segmented',
                        metavar='HID',
                        type=str,
                        default='')

    parser.add_argument('-g',
                        '--GPU_id',
                        help='The GPU number for the session to run on',
                        metavar='GPU',
                        type=str,
                        default='')

    parser.add_argument('-t',
                        '--threshold',
                        help='The confidence threshold to use when segmenting '
                        'the vascular and neural classes',
                        metavar='int',
                        type=int,
                        default=3)

    args = parser.parse_args()

    return Args(args.uNet_weights,
                args.jpg_directory,
                args.heart_id,
                args.GPU_id,
                args.threshold,
                )


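# Example invocation (hypothetical paths and IDs, shown only for
# illustration; substitute your own weights file and data directory):
#
#   python directory_segmentation.py \
#       -u ./weights/uNet_SCCE_weights.h5 \
#       -d /data/hearts/01_jpgs/ \
#       -i 01 -g 0 -t 3
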
# --------------------------------------------------
# --------------------------------------------------

class EncoderBlock(layers.Layer):
    '''This layer implements an encoder block with two convolutional layers
       and an option for returning both a max-pooled output, with a stride and
       pool size of (2,2), and the output of the second convolution for the
       skip connections implemented later in the network, during the decoding
       section. All padding is set to "same" for cleanliness.

       When initializing, it receives the number of filters to be used in both
       of the convolutional layers as well as the kernel size and stride for
       those same layers. It also receives the trainable variable for use with
       the batch normalization layers.'''

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 strides=(1,1),
                 trainable=True,
                 name='encoder_block',
                 **kwargs):

        super(EncoderBlock,self).__init__(trainable=trainable,
                                          name=name,
                                          **kwargs)
        # When initializing, this object receives a trainable parameter for
        # freezing the convolutional layers.

        # including the image normalization within the network for easier
        # image processing during inference
        self.image_normalization = layers.Rescaling(scale=1./255)

        # below creates the first of two convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv1',
                      trainable=trainable)

        # second of two convolutional layers
        self.conv2 = layers.Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      name='encoder_conv2',
                      trainable=trainable)

        # creates the max-pooling layer for downsampling the image.
        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
                                    strides=(2,2),
                                    padding='same',
                                    name='enc_pool')

        # ReLU layer for activations.
        self.ReLU = layers.ReLU()

        # both batch normalization layers, for use with their corresponding
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()

    def call(self,input,normalization=False,training=True,include_pool=True):

        # first conv of the encoder block
        if normalization:
            x = self.image_normalization(input)
            x = self.conv1(x)
        else:
            x = self.conv1(input)

        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second conv of the encoder block
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # calculate and include the max-pooling layer if include_pool is True.
        # The unpooled output is used for the skip connections later in the
        # network.
        if include_pool:
            pooled_x = self.enc_pool(x)
            return(x,pooled_x)

        else:
            return(x)


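# A minimal usage sketch for EncoderBlock (shapes are illustrative and assume
# a 1024x1024 RGB tile; this is not part of the saved model):
#
#   block = EncoderBlock(filters=24)
#   x = tf.zeros((1, 1024, 1024, 3))
#   skip, pooled = block(input=x, normalization=True, training=False)
#   # skip.shape   -> (1, 1024, 1024, 24)
#   # pooled.shape -> (1, 512, 512, 24)
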
# --------------------------------------------------

class DecoderBlock(layers.Layer):
    '''This layer implements a decoder block that, when called, receives both
       an input and a "skip connection". The input is passed to the
       "up convolution" or transpose conv layer to double the dimensions
       before being concatenated with its associated skip connection from the
       encoder section of the network. All padding is set to "same" for
       cleanliness. The decoder block also has an option for including an
       additional "segmentation" layer, which is a (1,1) convolution with 6
       filters that produces the logits for the one-hot encoded ground truth.

       When initializing, it receives the number of filters to be used in the
       up convolutional layer as well as the other two forward convolutions.
       The received kernel_size and strides are used for the forward
       convolutions, with the up convolution kernel and stride set to (2,2).'''
    def __init__(self,
                 filters,
                 trainable=True,
                 kernel_size=(3,3),
                 strides=(1,1),
                 name='DecoderBlock',
                 **kwargs):

        super(DecoderBlock,self).__init__(trainable=trainable,
                                          name=name,
                                          **kwargs)

        # creating the up convolution layer
        self.up_conv = layers.Conv2DTranspose(filters=filters,
                                              kernel_size=(2,2),
                                              strides=(2,2),
                                              padding='same',
                                              name='decoder_upconv',
                                              trainable=trainable)

        # the first of two forward convolutional layers
        self.conv1 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name='decoder_conv1',
                                   trainable=trainable)

        # second convolutional layer
        self.conv2 = layers.Conv2D(filters=filters,
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   padding='same',
                                   name='decoder_conv2',
                                   trainable=trainable)

        # this creates the output prediction logits layer.
        self.seg_out = layers.Conv2D(filters=6,
                        kernel_size=(1,1),
                        name='conv_feature_map')

        # ReLU for activation of all above layers
        self.ReLU = layers.ReLU()

        # the individual batch normalization layers for their respective
        # convolutional layers.
        self.batch_norm1 = tf.keras.layers.BatchNormalization()
        self.batch_norm2 = tf.keras.layers.BatchNormalization()


    def call(self,input,skip_conn,training=True,segmentation=False,prob_dist=True):

        up = self.up_conv(input) # perform image up convolution
        # concatenate the input and the skip_conn along the features axis
        concatenated = layers.concatenate([up,skip_conn],axis=-1)

        # first convolution
        x = self.conv1(concatenated)
        x = self.batch_norm1(x,training=training)
        x = self.ReLU(x)

        # second convolution
        x = self.conv2(x)
        x = self.batch_norm2(x,training=training)
        x = self.ReLU(x)

        # if segmentation is True, then run the segmentation (1,1) convolution
        # and use the Softmax to produce a probability distribution.
        if segmentation:
            seg = self.seg_out(x)
            # deliberately set as "float32" to ensure proper calculation if
            # switching to mixed precision for efficiency
            if prob_dist:
                seg = layers.Softmax(dtype='float32')(seg)

            return(seg)

        else:
            return(x)

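# A minimal usage sketch for DecoderBlock (illustrative shapes only, not part
# of the saved model):
#
#   dec = DecoderBlock(filters=24)
#   deep = tf.zeros((1, 512, 512, 48))    # output from the deeper block
#   skip = tf.zeros((1, 1024, 1024, 24))  # skip connection from the encoder
#   out = dec(input=deep, skip_conn=skip, training=False)
#   # out.shape -> (1, 1024, 1024, 24); with segmentation=True the block
#   # instead returns a 6-channel map, softmaxed when prob_dist=True.
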
# --------------------------------------------------

class uNet(keras.Model):
    '''This is a sub-classed model that uses the encoder and decoder blocks
       defined above to create a custom uNet. The differences from the
       original paper include a variable filter scalar (filter_multiplier),
       batch normalization between each convolutional layer and the
       associated ReLU activation, as well as feature normalization
       implemented in the first layer of the network.'''
    def __init__(self,filter_multiplier=2,**kwargs):
        super(uNet,self).__init__(**kwargs)

        # Defining encoder blocks
        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
                                           name='Enc1')
        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
                                           name='Enc2')
        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
                                           name='Enc3')
        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
                                           name='Enc4')
        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
                                           name='Enc5')

        # Defining decoder blocks. The names are in reverse order to make it
        # (hopefully) easier to understand which skip connections are
        # associated with which decoder layers.
        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
                                           name='Dec4')
        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
                                           name='Dec3')
        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
                                           name='Dec2')
        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
                                           name='Dec1')

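    # For reference, with the filter_multiplier=12 used in main() below, the
    # per-level filter counts work out to:
    #   Enc1/Dec1: 24, Enc2/Dec2: 48, Enc3/Dec3: 96, Enc4/Dec4: 192, Enc5: 384
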
    def call(self,inputs,training,predict=False,threshold=3):

        # encoder
        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
        enc5 = self.encoder_block5(input=enc4_pool,
                                   include_pool=False,
                                   training=training)

        # decoder
        dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
        prob_dist_out = self.decoder_block1(input=dec2,
                                            skip_conn=enc1,
                                            segmentation=True,
                                            training=training)
        if predict:
            seg_logits_out = self.decoder_block1(input=dec2,
                                                 skip_conn=enc1,
                                                 segmentation=True,
                                                 training=training,
                                                 prob_dist=False)

        # This prediction path is included to allow one to set a threshold on
        # the uncertainty, an arbitrary value compared against the maximum
        # logit predicted at each point in the image. Predictions for the
        # vascular and neural tissues are only kept if they are above the
        # confidence threshold; below the threshold, the predictions default
        # to muscle, connective tissue, or background.

        if predict:
            # rename the value for consistency and write protection.
            y_pred = seg_logits_out
            pred_shape = (1,1024,1024,6)
            # Getting an image-sized preliminary segmentation prediction
            squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))

            # initializing the variable used for storing the maximum logits at
            # each pixel location.
            max_value_predictions = tf.zeros((1024,1024))

            # cycle through all the classes
            for idx in range(6):

                # current class logits
                current_slice = tf.squeeze(y_pred[:,:,:,idx])
                # find the locations where this class is predicted
                current_indices = squeezed_prediction == idx
                # define the shape so that this function can run in graph mode
                # and not need eager execution.
                current_indices.set_shape((1024,1024))
                # Get the indices of where the idx class is predicted
                indices = tf.where(squeezed_prediction == idx)
                # get the output of boolean_mask to enable scatter update of
                # the tensor. This is required because tensors do not support
                # mask indexing.
                values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.double)
                # Place the maximum logit values at each point in an
                # image-sized matrix, indicating the confidence in the
                # prediction at each pixel.
                max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates.astype(tf.float32))

            # zero out the logits for the vascular (3) and neural (4) classes
            # wherever the winning class's logit falls below the threshold;
            # the neural class uses a threshold one lower than the vascular
            # class.
            for idx in [3,4]:
                mask_list = []
                for idx2 in range(6):
                    if idx2 == idx:

                        if idx2 == 4:
                            threshold = threshold - 1

                        mid_mask = max_value_predictions<threshold
                        mask_list.append(mid_mask.astype(tf.float32))
                    else:
                        mask_list.append(tf.zeros((1024,1024)))

                mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)

                indexes = tf.where(mask)
                values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.double)

                seg_logits_out = tf.tensor_scatter_nd_update(seg_logits_out,indexes,values_updates.astype(tf.float32))
                prob_dist_out = layers.Softmax(dtype='float32')(seg_logits_out)

        return(prob_dist_out)


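# The graph-mode scatter updates above can be hard to read; in plain NumPy the
# thresholding amounts to roughly the following (illustrative sketch only,
# assuming logits is a (1024,1024,6) array and softmax is defined elsewhere):
#
#   confidence = logits.max(axis=-1)                  # winning-class logit
#   logits[..., 3][confidence < threshold] = 0        # suppress vascular
#   logits[..., 4][confidence < threshold - 1] = 0    # suppress neural
#   probs = softmax(logits, axis=-1)
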
# --------------------------------------------------

def get_image_blocks(image,tile_distance=512,tile_size=1024):
    '''Receives an image as well as a minimum distance between tiles.
       Returns the image dimensions and a list of tile centers evenly
       distributed across the tissue surface.'''
    image_dimensions = image.shape

    # mark the region whose points lie at least half a tile away from every
    # edge, so a full tile can be extracted around each center.
    safe_mask = np.zeros([image_dimensions[0],image_dimensions[1]])
    safe_mask[int(tile_size/2):image_dimensions[0]-int(tile_size/2),
              int(tile_size/2):image_dimensions[1]-int(tile_size/2)] = 1

    grid_0 = np.arange(0,image_dimensions[0],tile_distance)
    grid_1 = np.arange(0,image_dimensions[1],tile_distance)

    center_indexes = []

    for grid0 in grid_0:
        for grid1 in grid_1:
            if safe_mask[grid0,grid1]:
                center_indexes.append([grid0,grid1])

    return([image_dimensions,center_indexes])

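# Usage sketch (illustrative): for a padded image and the defaults above,
#
#   dims, centers = get_image_blocks(image, tile_distance=512, tile_size=1024)
#
# returns centers on a 512 px grid that sit at least tile_size/2 from every
# edge, so a full 1024x1024 tile can be cut around each one.
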
# --------------------------------------------------

def get_reduced_tile_indexes(tile_center,returned_size=1024):
    '''Returns the [start, end] row and column indexes of a square tile of
       side returned_size centered on tile_center.'''
    start_0 = int(tile_center[0] - returned_size/2)
    end_0 = int(tile_center[0] + returned_size/2)

    start_1 = int(tile_center[1] - returned_size/2)
    end_1 = int(tile_center[1] + returned_size/2)

    return([start_0,end_0],[start_1,end_1])

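# For example:
#
#   get_reduced_tile_indexes([2048, 2048], returned_size=1024)
#   # -> ([1536, 2560], [1536, 2560])
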
# --------------------------------------------------

def segment_tiles(unet,center_indexes,image,threshold=3,scaling_factor=1,tile_size=1024):

    m,n,z = image.shape
    segmentation = np.zeros((m,n))

    for idx in tqdm.tqdm(range(len(center_indexes))):
        center = center_indexes[idx]
        dim0, dim1 = get_reduced_tile_indexes(center,tile_size)
        sub_sectioned_tile = image[dim0[0]:dim0[1],dim1[0]:dim1[1]]

        full_tile_dim0,full_tile_dim1,z = sub_sectioned_tile.shape

        color_tile = sub_sectioned_tile[:,:,0:3]

        if scaling_factor > 1:
            height = color_tile.shape[0]
            width = color_tile.shape[1]

            height2 = int(height/scaling_factor)
            width2 = int(width/scaling_factor)

            # cv.resize takes dsize as (width, height), and the interpolation
            # flag must be passed as a keyword argument.
            color_tile = cv.resize(color_tile,(width2,height2),
                                   interpolation=cv.INTER_AREA)

        color_tile = color_tile[None,:,:,:]

        prediction = unet(color_tile,predict=True,threshold=threshold)

        prediction_tile = np.squeeze(np.asarray(tf.argmax(prediction,axis=-1)).astype(np.float32).copy())

        if scaling_factor > 1:
            prediction_tile = cv.resize(prediction_tile,(full_tile_dim1,full_tile_dim0),
                                        interpolation=cv.INTER_NEAREST)

        dim0, dim1 = get_reduced_tile_indexes(center,returned_size=512)

        # fix this hard coding of the tile indexes for the prediction
        segmentation[dim0[0]:dim0[1],dim1[0]:dim1[1]] = prediction_tile[256:768,256:768]

    return(segmentation)

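# One way to address the hard-coded 256:768 crop noted above (an untested
# sketch): derive the crop from tile_size so only the central half of each
# tile is written back,
#
#   lo = tile_size // 4                     # 256 when tile_size == 1024
#   hi = tile_size - lo                     # 768 when tile_size == 1024
#   segmentation[dim0[0]:dim0[1], dim1[0]:dim1[1]] = prediction_tile[lo:hi, lo:hi]
#
# paired with get_reduced_tile_indexes(center, returned_size=tile_size // 2).
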
# --------------------------------------------------

def segment_directory(JPG_directory,
                      unet,tile_size=2048,
                      tile_distance=512,
                      scaling_factor=2,
                      HeartID='0',
                      threshold=3,
                      ):
    os.chdir(JPG_directory)

    out_directory = f'./../{HeartID}_uNet_Segmentations/'

    # create the directory for saving if it doesn't already exist
    if not os.path.isdir(out_directory):
        os.mkdir(out_directory)

    os.chdir(out_directory)

    # note: JPG_directory is assumed to be an absolute path here, since the
    # working directory was changed above.
    file_names = tf.io.gfile.glob(JPG_directory + HeartID + '*.jpg')

    for idx,file in enumerate(file_names):
        print(f'segmenting file {idx} of {len(file_names)}')

        file_id = file.split('/')[-1].split('.')[0]

        image = cv.imread(file,cv.IMREAD_UNCHANGED)
        # replicate-pad the image so tiles centered near the tissue edges
        # still have full context.
        image = cv.copyMakeBorder(image,4000,4000,4000,4000,cv.BORDER_REPLICATE)

        dimensions,center_indexes = get_image_blocks(image,
                                                    tile_distance=tile_distance,
                                                    tile_size=tile_size
                                                    )
        try:

            segmentation = segment_tiles(unet,
                             center_indexes,
                             image,
                             threshold=threshold,
                             scaling_factor=scaling_factor,
                             tile_size=tile_size)

        except Exception as e:
            # skip files that fail to segment rather than writing a stale or
            # undefined segmentation.
            print(f'segmentation failed for {file}: {e}')
            continue

        cv.imwrite(
            file_id +
            '_uNetSegmentation.png',
            segmentation
            )

    return()


# --------------------------------------------------
# --------------------------------------------------


def main() -> None:
    """ Main function for segmenting the provided directory with the given
    uNet weights. """

    args = get_args()
    uNet_file = args.uNet_weights
    JPG_directory = args.jpg_directory
    HeartID = args.heart_id
    GPU_ID = args.GPU_id
    threshold = args.threshold
    tile_size = 1024

    os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
    gpus = tf.config.list_physical_devices('GPU')

    if gpus:
        # Restrict TensorFlow to only allocate 8GB of memory on the first GPU
        try:
            tf.config.set_logical_device_configuration(
                gpus[0],
                [tf.config.LogicalDeviceConfiguration(memory_limit=8000)])
            logical_gpus = tf.config.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Virtual devices must be set before GPUs have been initialized
            print(e)

    # run a dummy batch through the network once so the weights are built
    # before loading the saved .h5 weights.
    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
    unet = uNet(filter_multiplier=12)
    _ = unet(sample_data,training=False)
    unet.summary()

    unet.load_weights(uNet_file)

    segment_directory(JPG_directory,
                    unet,
                    tile_size=tile_size,
                    tile_distance=512,
                    scaling_factor=1,
                    HeartID=HeartID,
                    threshold=threshold,
                    )


# --------------------------------------------------
if __name__ == '__main__':
    main()