Diff of /directory_segmentation.py [000000] .. [3b7fea]

--- a
+++ b/directory_segmentation.py
@@ -0,0 +1,596 @@
+#!/usr/bin/env python3
+"""
+Author : briancottle 
+Date   : 2022-12-14
+Purpose: Segment an entire directory of histological .jpg files using
+user-provided uNet weights generated and saved by the uNet_Subclassed_SCCE.py
+file.
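+
+Example invocation (paths and IDs below are illustrative):
+
+    python directory_segmentation.py \
+        -u ./weights/uNet_SCCE.h5 \
+        -d /data/hearts/jpgs/ \
+        -i 01 -g 0 -t 3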
+"""
+
+import argparse
+from typing import NamedTuple
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import mixed_precision
+from tensorflow.python.ops.numpy_ops import np_config
+np_config.enable_numpy_behavior()
+import cv2 as cv
+import os
+import tqdm
+from natsort import natsorted
+
+# --------------------------------------------------
+
+class Args(NamedTuple):
+    """ Command-line arguments """
+    uNet_weights: str
+    jpg_directory: str
+    heart_id: str
+    GPU_id: str
+    threshold: int
+
+
+# --------------------------------------------------
+
+def get_args() -> Args:
+    """ Get command-line arguments """
+
+    parser = argparse.ArgumentParser(
+        description='Providing command line arguments.',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument('-u',
+                        '--uNet_weights',
+                        type=str,
+                        metavar='uNet',
+                        help='Path to an .h5 file created using the '
+                        'architecture established in uNet_Subclassed_SCCE.py'
+                        ' file.')
+
+    parser.add_argument('-d',
+                        '--jpg_directory',
+                        help='The directory containing the .jpg files for '
+                        'segmentation',
+                        metavar='Dir',
+                        type=str,
+                        default='')
+
+    parser.add_argument('-i',
+                        '--heart_id',
+                        help='The two digit heart ID for the data being '
+                        'segmented',
+                        metavar='HID',
+                        type=str,
+                        default='')
+
+    parser.add_argument('-g',
+                        '--GPU_id',
+                        help='the GPU number for the session to run on',
+                        metavar='GPU',
+                        type=str,
+                        default='')
+
+    parser.add_argument('-t',
+                        '--threshold',
+                        help='confidence threshold below which vascular and '
+                        'neural predictions default to the remaining classes',
+                        metavar='int',
+                        type=int,
+                        default=3)
+
+    args = parser.parse_args()
+
+    return Args(args.uNet_weights,
+                args.jpg_directory, 
+                args.heart_id,
+                args.GPU_id,
+                args.threshold,
+                )
+
+
+# --------------------------------------------------
+# --------------------------------------------------
+
+class EncoderBlock(layers.Layer):
+    '''This layer implements an encoder block with two convolutional layers
+       and an option for returning both a max-pooled output (with a stride
+       and pool size of (2,2)) and the output of the second convolution for
+       the skip connections implemented later in the network during the
+       decoding section. All padding is set to "same" for cleanliness.
+
+       When initializing, it receives the number of filters to be used in
+       both of the convolutional layers as well as the kernel size and
+       strides for those same layers. It also receives the trainable flag
+       for use with the batch normalization layers.'''
+
+    def __init__(self,
+                 filters,
+                 kernel_size=(3,3),
+                 strides=(1,1),
+                 trainable=True,
+                 name='encoder_block',
+                 **kwargs):
+
+        super(EncoderBlock,self).__init__(trainable=trainable, name=name, **kwargs)
+        # When initializing, this object receives a trainable parameter for
+        # freezing the convolutional layers.
+
+        # including the image normalization within the network for easier image
+        # processing during inference
+        self.image_normalization = layers.Rescaling(scale=1./255)
+
+        # below creates the first of two convolutional layers
+        self.conv1 = layers.Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding='same',
+                      name='encoder_conv1',
+                      trainable=trainable)
+
+        # second of two convolutional layers
+        self.conv2 = layers.Conv2D(filters=filters,
+                      kernel_size=kernel_size,
+                      strides=strides,
+                      padding='same',
+                      name='encoder_conv2',
+                      trainable=trainable)
+
+        # creates the max-pooling layer for downsampling the image.
+        self.enc_pool = layers.MaxPool2D(pool_size=(2,2),
+                                    strides=(2,2),
+                                    padding='same',
+                                    name='enc_pool')
+
+        # ReLU layer for activations.
+        self.ReLU = layers.ReLU()
+        
+        # both batch normalization layers for use with their corresponding
+        # convolutional layers.
+        self.batch_norm1 = tf.keras.layers.BatchNormalization()
+        self.batch_norm2 = tf.keras.layers.BatchNormalization()
+
+    def call(self,input,normalization=False,training=False,include_pool=True):
+        
+        # first conv of the encoder block
+        if normalization:
+            x = self.image_normalization(input)
+            x = self.conv1(x)
+        else:
+            x = self.conv1(input)
+
+        x = self.batch_norm1(x,training=training)
+        x = self.ReLU(x)
+
+        # second conv of the encoder block
+        x = self.conv2(x)
+        x = self.batch_norm2(x,training=training)
+        x = self.ReLU(x)
+        
+        # calculate and include the max pooling layer if include_pool is true.
+        # This output is used for the skip connections later in the network.
+        if include_pool:
+            pooled_x = self.enc_pool(x)
+            return(x,pooled_x)
+
+        else:
+            return(x)
+
+
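+# A minimal shape sketch (illustrative, not executed by this script): an
+# EncoderBlock with 24 filters keeps the spatial size through both "same"
+# convolutions and halves it in the pooled copy used by the next block:
+#
+#   block = EncoderBlock(filters=24, name='demo_enc')
+#   feats, pooled = block(input=tf.zeros((1, 64, 64, 3)), training=False)
+#   # feats.shape  == (1, 64, 64, 24)  -> kept for the skip connection
+#   # pooled.shape == (1, 32, 32, 24)  -> fed to the next encoder block
+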
+# --------------------------------------------------
+
+class DecoderBlock(layers.Layer):
+    '''This layer implements a decoder block that, when called, receives
+       both an input and a "skip connection". The input is passed to the
+       "up convolution" or transpose conv layer to double the dimensions
+       before being concatenated with its associated skip connection from
+       the encoder section of the network. All padding is set to "same" for
+       cleanliness. The decoder block also has an option for including an
+       additional "segmentation" layer, which is a (1,1) convolution with 6
+       filters that produces the logits for the one-hot encoded ground
+       truth.
+
+       When initializing, it receives the number of filters to be used in
+       the up convolutional layer as well as in the other two forward
+       convolutions. The received kernel_size and strides are used for the
+       forward convolutions, with the up convolution kernel and stride both
+       set to (2,2).'''
+    def __init__(self,
+                 filters,
+                 trainable=True,
+                 kernel_size=(3,3),
+                 strides=(1,1),
+                 name='DecoderBlock',
+                 **kwargs):
+
+        super(DecoderBlock,self).__init__(trainable=trainable, name=name, **kwargs)
+
+        # creating the up convolution layer
+        self.up_conv = layers.Conv2DTranspose(filters=filters,
+                                              kernel_size=(2,2),
+                                              strides=(2,2),
+                                              padding='same',
+                                              name='decoder_upconv',
+                                              trainable=trainable)
+
+        # the first of two forward convolutional layers
+        self.conv1 = layers.Conv2D(filters=filters,
+                                   kernel_size=kernel_size,
+                                   strides=strides,
+                                   padding='same',
+                                   name ='decoder_conv1',
+                                   trainable=trainable)
+
+        # second convolutional layer
+        self.conv2 = layers.Conv2D(filters=filters,
+                                   kernel_size=kernel_size,
+                                   strides=strides,
+                                   padding='same',
+                                   name ='decoder_conv2',
+                                   trainable=trainable)
+
+        # this creates the output prediction logits layer.
+        self.seg_out = layers.Conv2D(filters=6,
+                        kernel_size=(1,1),
+                        name='conv_feature_map')
+
+        # ReLU for activation of all above layers
+        self.ReLU = layers.ReLU()
+        
+        # the individual batch normalization layers for their respective 
+        # convolutional layers.
+        self.batch_norm1 = tf.keras.layers.BatchNormalization()
+        self.batch_norm2 = tf.keras.layers.BatchNormalization()
+
+
+    def call(self,input,skip_conn,training=False,segmentation=False,prob_dist=True):
+        
+        up = self.up_conv(input) # perform image up convolution
+        # concatenate the input and the skip_conn along the features axis
+        concatenated = layers.concatenate([up,skip_conn],axis=-1)
+
+        # first convolution 
+        x = self.conv1(concatenated)
+        x = self.batch_norm1(x,training=training)
+        x = self.ReLU(x)
+
+        # second convolution
+        x = self.conv2(x)
+        x = self.batch_norm2(x,training=training)
+        x = self.ReLU(x)
+
+        # if segmentation is True, then run the segmentation (1,1) convolution
+        # and use the Softmax to produce a probability distribution.
+        if segmentation:
+            seg = self.seg_out(x)
+            # deliberately set as "float32" to ensure proper calculation if 
+            # switching to mixed precision for efficiency
+            if prob_dist:
+                seg = layers.Softmax(dtype='float32')(seg)
+
+            return(seg)
+
+        else:
+            return(x)
+
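+# A minimal shape sketch (illustrative, not executed by this script): the
+# transpose convolution doubles the spatial size to match the skip
+# connection before the concatenation and forward convolutions:
+#
+#   dec = DecoderBlock(filters=24, name='demo_dec')
+#   out = dec(input=tf.zeros((1, 32, 32, 48)),
+#             skip_conn=tf.zeros((1, 64, 64, 24)),
+#             training=False)
+#   # out.shape == (1, 64, 64, 24)
+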
+# --------------------------------------------------
+
+class uNet(keras.Model):
+    '''This is a sub-classed model that uses the encoder and decoder blocks
+       defined above to create a custom unet. The differences from the original 
+       paper include a variable filter scalar (filter_multiplier), batch 
+       normalization between each convolutional layer and the associated ReLU 
+       activation, as well as feature normalization implemented in the first 
+       layer of the network.'''
+    def __init__(self,filter_multiplier=2,**kwargs):
+        super(uNet,self).__init__()
+        
+        # Defining encoder blocks
+        self.encoder_block1 = EncoderBlock(filters=2*filter_multiplier,
+                                           name='Enc1')
+        self.encoder_block2 = EncoderBlock(filters=4*filter_multiplier,
+                                           name='Enc2')
+        self.encoder_block3 = EncoderBlock(filters=8*filter_multiplier,
+                                           name='Enc3')
+        self.encoder_block4 = EncoderBlock(filters=16*filter_multiplier,
+                                           name='Enc4')
+        self.encoder_block5 = EncoderBlock(filters=32*filter_multiplier,
+                                           name='Enc5')
+
+        # Defining decoder blocks. The names are in reverse order to make it 
+        # (hopefully) easier to understand which skip connections are associated
+        # with which decoder layers.
+        self.decoder_block4 = DecoderBlock(filters=16*filter_multiplier,
+                                           name='Dec4')
+        self.decoder_block3 = DecoderBlock(filters=8*filter_multiplier,
+                                           name='Dec3')
+        self.decoder_block2 = DecoderBlock(filters=4*filter_multiplier,
+                                           name='Dec2')
+        self.decoder_block1 = DecoderBlock(filters=2*filter_multiplier,
+                                           name='Dec1')
+
+
+    def call(self,inputs,training=False,predict=False,threshold=3):
+
+        # encoder    
+        enc1,enc1_pool = self.encoder_block1(input=inputs,normalization=True,training=training)
+        enc2,enc2_pool = self.encoder_block2(input=enc1_pool,training=training)
+        enc3,enc3_pool = self.encoder_block3(input=enc2_pool,training=training)
+        enc4,enc4_pool = self.encoder_block4(input=enc3_pool,training=training)
+        enc5 = self.encoder_block5(input=enc4_pool,
+                                   include_pool=False,
+                                   training=training)
+
+
+
+        # decoder
+        dec4 = self.decoder_block4(input=enc5,skip_conn=enc4,training=training)
+        dec3 = self.decoder_block3(input=dec4,skip_conn=enc3,training=training)
+        dec2 = self.decoder_block2(input=dec3,skip_conn=enc2,training=training)
+        prob_dist_out = self.decoder_block1(input=dec2,
+                                            skip_conn=enc1,
+                                            segmentation=True,
+                                            training=training)
+        if predict:
+            seg_logits_out = self.decoder_block1(input=dec2,
+                                                 skip_conn=enc1,
+                                                 segmentation=True,
+                                                 training=training,
+                                                 prob_dist=False)
+
+        # This prediction path allows one to set a threshold on the
+        # uncertainty, an arbitrary value compared against the maximum logit
+        # predicted at each point in the image. Predictions for the vascular
+        # and neural tissues are included only where they are above the
+        # confidence threshold; below it, the predictions default to muscle,
+        # connective, or background.
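+        # Illustrative numbers (not from the data): with threshold=3, a
+        # pixel whose winning class is 3 and whose maximum logit is only 2.1
+        # has that class-3 logit reset to 0 below, so the final argmax falls
+        # back to whichever other class has the largest logit at that pixel.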
+        
+        if predict:
+            # rename the value for consistency and write protection.
+            y_pred = seg_logits_out
+            pred_shape = (1,1024,1024,6)
+            # Getting an image-sized preliminary segmentation prediction
+            squeezed_prediction = tf.squeeze(tf.argmax(y_pred,axis=-1))
+
+            # initializing the variable used for storing the maximum logits at 
+            # each pixel location.
+            max_value_predictions = tf.zeros((1024,1024))
+
+            # cycle through all the classes 
+            for idx in range(6):
+                
+                # current class logits
+                current_slice = tf.squeeze(y_pred[:,:,:,idx])
+                # find the locations where this class is predicted
+                current_indices = squeezed_prediction == idx
+                # define the shape so that this function can run in graph mode
+                # and not need eager execution.
+                current_indices.set_shape((1024,1024))
+                # Get the indices of where the idx class is predicted
+                indices = tf.where(squeezed_prediction == idx)
+                # get the output of boolean_mask to enable scatter update of the
+                # tensor. This is required because tensors do not support 
+                # mask indexing.
+                values_updates = tf.boolean_mask(current_slice,current_indices).astype(tf.float32)
+                # Place the maximum logit values at each point in an
+                # image-size matrix, indicating the confidence in the
+                # prediction at each pixel.
+                max_value_predictions = tf.tensor_scatter_nd_update(max_value_predictions,indices,values_updates)
+            
+            # Zero out the logits for classes 3 and 4 wherever the winning
+            # logit falls below the confidence threshold, so the prediction
+            # defaults to muscle, connective, or background there. Class 4
+            # is compared against a threshold one lower than the provided
+            # value.
+            for idx in [3,4]:
+                class_threshold = threshold - 1 if idx == 4 else threshold
+                mask_list = []
+                for idx2 in range(6):
+                    if idx2 == idx:
+                        # mark the low-confidence pixels for this class
+                        mid_mask = max_value_predictions < class_threshold
+                        mask_list.append(mid_mask.astype(tf.float32))
+                    else:
+                        mask_list.append(tf.zeros((1024,1024)))
+
+                mask = tf.expand_dims(tf.stack(mask_list,axis=-1),axis=0)
+
+                # tensors do not support mask indexing, so scatter zeros
+                # into the masked logit locations instead.
+                indexes = tf.where(mask)
+                values_updates = tf.boolean_mask(tf.zeros(pred_shape),mask).astype(tf.float32)
+                seg_logits_out = tf.tensor_scatter_nd_update(seg_logits_out,indexes,values_updates)
+
+            # recompute the probability distribution from the thresholded
+            # logits once, after both classes have been masked.
+            prob_dist_out = layers.Softmax(dtype='float32')(seg_logits_out)
+
+
+            
+        return(prob_dist_out)
+
+
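+# A minimal usage sketch mirroring main() below: because uNet is a
+# subclassed model, its weights only exist after a first call, so the model
+# is built by running a dummy batch through it before load_weights (the
+# weights path here is illustrative):
+#
+#   unet = uNet(filter_multiplier=12)
+#   _ = unet(np.zeros((1,1024,1024,3)).astype(np.int8), training=False)
+#   unet.load_weights('model_weights.h5')
+#   probs = unet(image_batch, training=False, predict=True, threshold=3)
+#   # probs: (1, 1024, 1024, 6) per-pixel class probability distribution
+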
+# --------------------------------------------------
+
+def get_image_blocks(image,tile_distance=512,tile_size=1024):
+    '''Receives an image as well as a minimum distance between tiles and the
+       tile size. Returns the image dimensions and a list of tile centers
+       evenly distributed across the image, excluding any centers too close
+       to the border to extract a full tile.'''
+    image_dimensions = image.shape
+
+    safe_mask = np.zeros([image_dimensions[0],image_dimensions[1]])
+    safe_mask[int(tile_size/2):image_dimensions[0]-int(tile_size/2),
+              int(tile_size/2):image_dimensions[1]-int(tile_size/2)] = 1
+
+    grid_0 = np.arange(0,image_dimensions[0],tile_distance)
+    grid_1 = np.arange(0,image_dimensions[1],tile_distance)
+
+    
+
+    center_indexes = []
+
+    for grid0 in grid_0:
+        for grid1 in grid_1:
+            if safe_mask[grid0,grid1]:
+                center_indexes.append([grid0,grid1])
+
+    return([image_dimensions,center_indexes])
+
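+# For example (illustrative numbers): a 3000x3000 image with tile_size=1024
+# and tile_distance=512 yields grid points at multiples of 512, of which
+# only those at least 512 pixels from every border survive the safe_mask,
+# leaving the 16 centers {512, 1024, 1536, 2048} x {512, 1024, 1536, 2048}.
+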
+# --------------------------------------------------
+
+def get_reduced_tile_indexes(tile_center,returned_size=1024):
+    '''Returns the [start,end] index pairs along each axis for a square
+       window of side returned_size centered on tile_center.'''
+    start_0 = int(tile_center[0] - returned_size/2)
+    end_0 = int(tile_center[0] + returned_size/2)
+
+    start_1 = int(tile_center[1] - returned_size/2)
+    end_1 = int(tile_center[1] + returned_size/2)
+
+    return([start_0,end_0],[start_1,end_1])
+
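+# For example (illustrative numbers), get_reduced_tile_indexes((1024, 2048),
+# returned_size=512) returns ([768, 1280], [1792, 2304]): a 512x512 window
+# centered on the given point.
+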
+# --------------------------------------------------
+
+def segment_tiles(unet,center_indexes,image,threshold=3,scaling_factor=1,tile_size=1024):
+    '''Runs the network over a tile centered on each of center_indexes and
+       assembles the per-pixel class predictions into one full-image
+       segmentation.'''
+
+    m,n,z = image.shape
+    segmentation = np.zeros((m,n))
+
+    for idx in tqdm.tqdm(range(len(center_indexes))):
+        center = center_indexes[idx]
+        dim0, dim1 = get_reduced_tile_indexes(center,tile_size)
+        sub_sectioned_tile = image[dim0[0]:dim0[1],dim1[0]:dim1[1]] 
+
+        full_tile_dim0,full_tile_dim1,z = sub_sectioned_tile.shape
+
+        color_tile = sub_sectioned_tile[:,:,0:3]
+
+        if scaling_factor > 1:
+            height = color_tile.shape[0]
+            width = color_tile.shape[1]
+
+            height2 = int(height/scaling_factor)
+            width2 = int(width/scaling_factor)
+            
+            # cv.resize expects dsize as (width, height); the interpolation
+            # flag must be passed by keyword (the third positional arg is dst)
+            color_tile = cv.resize(color_tile,(width2,height2),interpolation=cv.INTER_AREA)
+
+        color_tile = color_tile[None,:,:,:]
+
+        prediction = unet(color_tile,predict=True,threshold=threshold)
+
+        prediction_tile = np.squeeze(np.asarray(tf.argmax(prediction,axis=-1)).astype(np.float32).copy())
+
+        if scaling_factor > 1:
+            prediction_tile = cv.resize(prediction_tile,(full_tile_dim1,full_tile_dim0),interpolation=cv.INTER_NEAREST)
+
+
+        dim0, dim1 = get_reduced_tile_indexes(center,returned_size=512)
+
+        # stamp the central 512x512 region of the prediction tile into the
+        # full segmentation, computed from the tile dimensions rather than
+        # hard coded for 1024x1024 tiles.
+        c0 = full_tile_dim0//2
+        c1 = full_tile_dim1//2
+        segmentation[dim0[0]:dim0[1],dim1[0]:dim1[1]] = prediction_tile[c0-256:c0+256,c1-256:c1+256]
+
+    return(segmentation)
+
+# --------------------------------------------------
+
+def segment_directory(JPG_directory,
+                      unet,
+                      tile_size=2048,
+                      tile_distance=512,
+                      scaling_factor=2,
+                      HeartID='0',
+                      threshold=3,
+                      ):
+    os.chdir(JPG_directory)
+
+    out_directory = f'./../{HeartID}_uNet_Segmentations/'
+
+    # create the directory for saving if it doesn't already exist
+    if not os.path.isdir(out_directory):
+        os.mkdir(out_directory)
+
+    os.chdir(out_directory)
+
+    # natural sort keeps the files in numeric order (e.g. 2 before 10).
+    file_names = natsorted(tf.io.gfile.glob(os.path.join(JPG_directory, HeartID + '*.jpg')))
+
+    for idx,file in enumerate(file_names):
+        print(f'segmenting file {idx+1} of {len(file_names)}')
+
+        file_id = file.split('/')[-1].split('.')[0]
+
+        image = cv.imread(file,cv.IMREAD_UNCHANGED)
+        image = cv.copyMakeBorder(image,4000,4000,4000,4000,cv.BORDER_REPLICATE)
+
+        dimensions,center_indexes = get_image_blocks(image,
+                                                    tile_distance=tile_distance,
+                                                    tile_size=tile_size
+                                                    )
+        try:
+            segmentation = segment_tiles(unet,
+                             center_indexes,
+                             image,
+                             threshold=threshold,
+                             scaling_factor=scaling_factor,
+                             tile_size=tile_size)
+
+        except Exception as e:
+            # report the failure and skip this file rather than writing out
+            # a stale or undefined segmentation below.
+            print(f'segmentation failed for {file}: {e}')
+            continue
+
+        cv.imwrite(
+            file_id +
+            '_uNetSegmentation.png',
+            segmentation
+            )
+
+    return()
+
+
+# --------------------------------------------------
+# --------------------------------------------------
+
+
+def main() -> None:
+    """ Main function for segmenting the provided directory with the given 
+    uNet weights. """
+
+    args = get_args()
+    uNet_file = args.uNet_weights
+    JPG_directory = args.jpg_directory
+    HeartID = args.heart_id
+    GPU_ID = args.GPU_id
+    threshold = args.threshold
+    tile_size = 1024
+    
+    os.environ["CUDA_VISIBLE_DEVICES"]=GPU_ID
+    gpus = tf.config.list_physical_devices('GPU')
+
+    if gpus:
+        # Restrict TensorFlow to only allocate 8GB of memory on the first GPU
+        try:
+            tf.config.set_logical_device_configuration(
+                gpus[0],
+                [tf.config.LogicalDeviceConfiguration(memory_limit=8000)])
+            logical_gpus = tf.config.list_logical_devices('GPU')
+            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
+        except RuntimeError as e:
+            # Virtual devices must be set before GPUs have been initialized
+            print(e)
+
+    # calling the subclassed model on a dummy batch builds its weights so
+    # that load_weights below has variables to populate.
+    sample_data = np.zeros((1,1024,1024,3)).astype(np.int8)
+    unet = uNet(filter_multiplier=12)
+    _ = unet(sample_data, training=False)
+    unet.summary()
+
+    unet.load_weights(uNet_file)
+
+    segment_directory(JPG_directory,
+                    unet,
+                    tile_size=tile_size,
+                    tile_distance=512,
+                    scaling_factor=1,
+                    HeartID=HeartID,
+                    threshold=threshold,
+                    )
+
+
+# --------------------------------------------------
+if __name__ == '__main__':
+    main()