Diff of /TF_WriteToRecordFile.py [000000] .. [3b7fea]

Switch to side-by-side view

--- a
+++ b/TF_WriteToRecordFile.py
@@ -0,0 +1,427 @@
+# %% importing packages
+
+import numpy as np
+import tensorflow as tf
+from skimage import measure
+import skimage.transform as transform
+import cv2 as cv
+import os
+import tqdm
+import matplotlib.pyplot as plt
+import random
+
+# %% Citations
+#############################################################
+#############################################################
+
+# https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c
+# https://keras.io/examples/keras_recipes/creating_tfrecords/
+# https://www.tensorflow.org/tutorials/load_data/tfrecord
+
+# %% Defining TF Records Helper Functions
+#############################################################
+#############################################################
+
+# These following functions were created from the TDS blog and 
+# the tensorflow suggestions/tutorial
+
+def int64_feature(value):
+  """Returns an int64_list from a bool / enum / int / uint."""
+  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+def float_feature_list(value):
+    """Returns a list of float_list from a float / double."""
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def bytes_feature(value):
+    """Returns a bytes_list from a string / byte."""
+    if isinstance(value, type(tf.constant(0))): # if value is tensor
+        value = value.numpy() # get value of tensor
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+def parse_image(full_image,seg,bbox,image_name):
+
+    data = {
+        'height' : int64_feature(full_image.shape[0]),
+        'width' : int64_feature(full_image.shape[1]),
+        'raw_image' : bytes_feature(
+            tf.io.serialize_tensor(full_image)),
+        'raw_seg' : bytes_feature(
+            tf.io.serialize_tensor(seg)),
+        'bbox_x' : float_feature_list(bbox[0]),
+        'bbox_y' : float_feature_list(bbox[1]),
+        'bbox_width' : float_feature_list(bbox[2]),
+        'bbox_height' : float_feature_list(bbox[3]),
+        'name' : bytes_feature(bytes(image_name, encoding='utf-8')),
+    }
+
+    parsed_image = tf.train.Example(features=tf.train.Features(feature=data))
+
+    return(parsed_image)
+
+# Defining Functions
+#############################################################
+#############################################################
+
+def get_bounding_boxes(binary_image):
+    # This function receives a binary image, and returns a list of the
+    # bounding boxes that surround the positive connected components
+
+    labeled_image = measure.label(binary_image) # labeling image
+    regions = measure.regionprops(labeled_image) # getting region props
+    
+    # lists for storing the sequence of x, y, width, and height data for 
+    # the image. They are separate so that each can be its own list that 
+    # is stored, seemed to make sense with the limitations of the 
+    # float_feature_list thing. 
+    box_x = []
+    box_y = []
+    box_width = []
+    box_height = []
+
+    # iterating over the number of regions found in the image
+    for region in regions:
+        # retrieving [min_row, min_col, max_row, max_col]
+        bounding_box = region['bbox'] 
+
+        # Calculating the center of the bounding box, as this is a common
+        # format for the box parameters in SSD networks
+        x = np.floor(np.mean([bounding_box[2],bounding_box[0]]))
+        y = np.floor(np.mean([bounding_box[3],bounding_box[1]]))
+
+        # retriving the width and height, also common format
+        width = bounding_box[2]-bounding_box[0]
+        height = bounding_box[3]-bounding_box[1]
+
+        # appending to respective lists for storage
+        box_x.append(x)
+        box_y.append(y)
+        box_width.append(width)
+        box_height.append(height)
+
+
+    return([box_x,box_y,box_width,box_height])
+
+#############################################################
+
+def load_image_names(folder):
+    # This function reads the images within a folder while filtering 
+    # out the weird invisible files that macos includes in their folders
+    
+    file_list = []
+    for file_name in os.listdir(folder):
+
+        # check if the first character of the name is a '.', skip if so
+        if file_name[0] != '.': 
+            file_list.append(file_name)
+
+    return(file_list)
+
+#############################################################
+
+def parse_image_bbox(file_name,bbox_class_id,reduction_size=1):
+    # This function receives a name and the class id of which you want to 
+    # provide bounding boxes for in the dataset
+    image = cv.imread(file_name,cv.IMREAD_UNCHANGED)
+    passed = False
+    parsed_image = 0
+    if image.shape[0] == image.shape[1]:
+        passed = True
+        # separating out the color image
+        try:
+            color_image = image[:,:,0:3]
+        except Exception as e:
+            print(file_name)
+            print(image.shape)
+            print(e)
+            
+        # downsampling the color image
+        if reduction_size > 1:
+            height = color_image.shape[0]
+            width = color_image.shape[1]
+
+            height2 = int(height/reduction_size)
+            width2 = int(width/reduction_size)
+
+            color_image = cv.resize(color_image,[height2,width2],cv.INTER_AREA)
+
+        # getting the segmentation for the bbox production, making compensations
+        # for odd segmentations
+        seg = image[:,:,3]
+        # seg[seg==4] = 2
+        # seg[seg==5] = 4
+        # seg[seg==6] = 5
+        # seg[seg==7] = 6
+
+        if reduction_size > 1:
+            height = seg.shape[0]
+            width = seg.shape[1]
+
+            height2 = int(height/reduction_size)
+            width2 = int(width/reduction_size)
+
+            seg = cv.resize(seg,[height2,width2],cv.INTER_NEAREST)
+
+        # creating the binary image for bboxes
+        bbox_seg = seg == bbox_class_id
+        bbox = get_bounding_boxes(bbox_seg)
+
+        parsed_image = parse_image(color_image,seg,bbox,image_name=file_name)
+
+    return(parsed_image,passed)
+    
+
+#############################################################
+
+def get_shard_sizes(file_names,max_files_per_shard):
+    # This function receives a list of file names and the maximum number of 
+    # files you want to save in a shard.
+    num_images = len(file_names)
+    # determining the number of splits for the image. The +1 collects any 
+    # stragglers that don't completely "fill up" a shard. It is removed if the 
+    # number comes out with no remainder. 
+    num_splits = num_images//max_files_per_shard + 1
+    if num_images%max_files_per_shard == 0:
+        num_splits -= 1
+    
+    print(f'Using {num_splits} shards to store {num_images} images.')
+
+    return(num_splits,max_files_per_shard)
+
+#############################################################
+
+def write_all_images_to_shards(file_names,
+                               num_splits,
+                               max_files_per_shard,
+                               bbox_id=5,
+                               reduction_size=1):
+    # this function receives a list of file names, the number of splits
+    # produced by get_shard_sizes, the how many files will be put in each
+    # shard, and the class id of which segmentation you want to produce
+    # bounding boxes for. In the future you could easily add the functionality
+    # to produce bboxes for both vasculature and neural tissues.
+    out_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_6/'
+
+    # create the directory for saving if it doesn't already exist
+    if not os.path.isdir(out_directory):
+        os.mkdir(out_directory)
+
+    # keeps track of how many files have been written total    
+    files_written = 0
+
+    # displaying progress based on number of shards successfully saved
+    for idx in tqdm.tqdm(range(num_splits)):
+    
+        current_shard_name = out_directory + \
+            f'shard_{idx+1}_of_{num_splits}.tfrecords'
+        
+        # create the writer for the current shard
+        writer = tf.io.TFRecordWriter(current_shard_name)
+        
+        # keep track of how many files we've put in this shard
+        num_files_this_shard = 0
+
+        # exit if we hit the max number of files for this shard
+        while num_files_this_shard < max_files_per_shard:
+            # keeping track of names of files across shards
+            image_idx = idx*max_files_per_shard + num_files_this_shard
+            # if we hit the end of all the files, stop this shard
+            if image_idx == len(file_names):
+                break
+            # get a parsed image
+            parsed_image,passed = parse_image_bbox(file_names[image_idx],
+                                            bbox_id,
+                                            reduction_size=reduction_size)
+
+            # add the current image to the tfrecord file
+            if passed:
+                writer.write(parsed_image.SerializeToString())
+
+                num_files_this_shard += 1
+                files_written += 1
+            else:
+                print(f'{file_names[image_idx]} failed, skipping')
+                num_files_this_shard += 1
+
+        writer.close()
+    
+    print(f'{files_written} files have been written to the tfrecord.')
+    return(files_written)
+
+#############################################################
+
+def parse_tf_elements(element):
+    '''This function is the mapper function for retrieving examples from the
+       tfrecord'''
+
+    # create placeholders for all the features in each example
+    data = {
+        'height' : tf.io.FixedLenFeature([],tf.int64),
+        'width' : tf.io.FixedLenFeature([],tf.int64),
+        'raw_image' : tf.io.FixedLenFeature([],tf.string),
+        'raw_seg' : tf.io.FixedLenFeature([],tf.string),
+        'bbox_x' : tf.io.VarLenFeature(tf.float32),
+        'bbox_y' : tf.io.VarLenFeature(tf.float32),
+        'bbox_height' : tf.io.VarLenFeature(tf.float32),
+        'bbox_width' : tf.io.VarLenFeature(tf.float32),
+        'name' : tf.io.FixedLenFeature([],tf.string),
+    }
+
+    # pull out the current example
+    content = tf.io.parse_single_example(element, data)
+
+    # pull out each feature from the example 
+    height = content['height']
+    width = content['width']
+    raw_seg = content['raw_seg']
+    raw_image = content['raw_image']
+    name = content['name']
+    bbox_x = content['bbox_x']
+    bbox_y = content['bbox_y']
+    bbox_height = content['bbox_height']
+    bbox_width = content['bbox_width']
+
+    # convert the images to uint8, and reshape them accordingly
+    image = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
+    image = tf.reshape(image,shape=[height,width,3])
+    segmentation = tf.io.parse_tensor(raw_seg, out_type=tf.uint8)
+    segmentation = tf.reshape(segmentation,shape=[height,width,1])
+    one_hot_seg = tf.one_hot(tf.squeeze(segmentation),7,axis=-1)
+
+    # there currently is a bug with returning the bbox, but isn't necessary
+    # to fix for creating the initial uNet for segmentation exploration
+    
+    # bbox = [bbox_x,bbox_y,bbox_height,bbox_width]
+
+    return(image,one_hot_seg,name)
+
+#############################################################
+#############################################################
+
+# %%
+# writing the files to a new directory!
+dataset_directory = '/home/briancottle/Research/Semantic_Segmentation/sub_sampled_20221129'
+os.chdir(dataset_directory)
+file_names = load_image_names(dataset_directory)
+num_splits,max_files_per_shard = get_shard_sizes(file_names,400)
+
+# %%
+write_all_images_to_shards(file_names,
+                           num_splits,
+                           max_files_per_shard,
+                           bbox_id=5,
+                           reduction_size=1)
+
+# %% loading an example shard, and creating the mapped dataset
+os.chdir('/home/briancottle/Research/Semantic_Segmentation/dataset_shards_6/')
+dataset = tf.data.TFRecordDataset('shard_10_of_128.tfrecords')
+dataset = dataset.map(parse_tf_elements)
+# %% 
+# double checking some of the examples to make sure it all worked well!
+for sample in dataset.take(100):
+    print(sample[2])
+    plt.imshow(cv.cvtColor(np.asarray(sample[0]),cv.COLOR_BGR2RGB))
+    print(sample[0].shape)
+    plt.show()
+    seg = tf.argmax(sample[1],axis=-1)
+    plt.imshow(seg,vmin=0,vmax=6)
+    plt.show()
+    print(np.max(seg))
+    print(np.unique(seg))
+    
+
+
+# %% Checking for dud/empty files
+
+for file_name in file_names:
+    image = cv.imread(file_name,cv.IMREAD_UNCHANGED)
+    try:
+        print(image.shape)
+        assert image.shape == (1024,1024,3)
+    except Exception as e:
+        print(e)
+        print(file_name)
+    
+
+
+# %% Further data exploration to ensure proper storage in the record files!
+GT = []
+Images = []
+interested = 12
+count = 0
+for sample in dataset.take(30):
+    if count == interested:
+        GT.append(sample[1])
+        gt = tf.squeeze(sample[1])
+        Images.append(sample[0])
+        # print('layer 7')
+        # plt.imshow(gt[:,:,7])
+        # print(np.sum(gt[:,:,7]))
+        # plt.show()
+        print('layer 6')
+        plt.imshow(gt[:,:,6])
+        print(np.sum(gt[:,:,6]))
+        plt.show()
+        print('layer 5')
+        plt.imshow(gt[:,:,5])
+        print(np.sum(gt[:,:,5]))
+        plt.show()
+        print('layer 4')
+        plt.imshow(gt[:,:,4])
+        print(np.sum(gt[:,:,4]))
+        plt.show()
+        print('layer 3')
+        plt.imshow(gt[:,:,3])
+        print(np.sum(gt[:,:,3]))
+        plt.show()
+        print('layer 2')
+        plt.imshow(gt[:,:,2])
+        print(np.sum(gt[:,:,2]))
+        plt.show()
+        print('layer 1')
+        plt.imshow(gt[:,:,1])
+        print(np.sum(gt[:,:,1]))
+        plt.show()
+        print('layer 0')
+        plt.imshow(gt[:,:,0])
+        print(np.sum(gt[:,:,0]))
+        plt.show()
+    count += 1
+
+# %%
+# directory where the dataset shards are stored
+shard_dataset_directory = '/home/briancottle/Research/Semantic_Segmentation/dataset_shards_6'
+
+os.chdir(shard_dataset_directory)
+if not os.path.isdir('./train'):
+    os.mkdir('./train')
+    os.mkdir('./validate')
+    os.mkdir('./test')
+
+# only get the file names that follow the shard naming convention
+file_names = tf.io.gfile.glob(shard_dataset_directory + \
+                              "/shard_*_of_*.tfrecords")
+
+random.shuffle(file_names)
+
+# first 80% of names go to the training dataset. Following 10% go to the val
+# dataset, followed by last 10% go to the testing dataset.
+val_split_idx = int(0.80*len(file_names))
+test_split_idx = int(0.90*len(file_names))
+
+# separate the file names out
+train_files, val_files, test_files = file_names[:val_split_idx],\
+                                     file_names[val_split_idx:test_split_idx],\
+                                     file_names[test_split_idx:]
+
+
+for sub_dir,files in zip(['/train/shard','/validate/shard','/test/shard'],
+                         [train_files,val_files,test_files]):
+
+    for file_name in files:
+        new_name = file_name.split('/shard')[0] + sub_dir + file_name.split('/shard')[1]
+        os.rename(file_name,new_name)
+
+
+# %%