Diff of /data_generator.py [000000] .. [7b5b9f]

Switch to side-by-side view

--- a
+++ b/data_generator.py
@@ -0,0 +1,193 @@
+import tensorflow as tf
+import numpy as np
+import os
+# from matplotlib import pyplot as plt
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework.ops import convert_to_tensor
+import skimage as sk
+from skimage import transform
+import SimpleITK as sitk
+
# Per-channel RGB means (the well-known ImageNet training-set statistics).
# NOTE(review): unused anywhere in this file -- presumably kept for callers
# that do ImageNet-style mean subtraction; confirm before removing.
IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32)
+
+
class ImageDataGenerator(object):
    """Wraps a list of image/label volume paths in a ``tf.data.Dataset``
    that yields random axial patches (image slab + one-hot mask)."""

    def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, buffer_size=5):
        """Create a new ImageDataGenerator.

        Receives a path string to a text file where each line holds the path
        to an image volume and the path to its label volume, separated by a
        comma (see ``parse_fn``, which splits each line on ',').

        Args:
            txt_file: path to the text file.
            mode: either 'training' or 'inference'. Depending on this value,
                different parsing functions will be used.
            batch_size: number of images per batch.
            num_classes: number of classes in the dataset.
            shuffle: whether or not to shuffle the data in the dataset and the
                initial file list.
            buffer_size: number of images used as buffer for TensorFlow's
                shuffling of the dataset.

        Raises:
            ValueError: If an invalid mode is passed.
        """
        self.txt_file = txt_file
        self.num_classes = num_classes

        # retrieve the data from the text file
        self._read_txt_file()

        # number of samples in the dataset
        self.data_size = len(self.img_paths)

        # initial shuffling of the file list
        if shuffle:
            self._shuffle_lists()

        # convert list to TF tensor
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)

        # create dataset
        data = tf.data.Dataset.from_tensor_slices((self.img_paths))

        # repeat indefinitely (train.py will count the epochs)
        data = data.repeat()

        # wrap the numpy patch extractor; emits (image_patch, mask_patch) as
        # float32. NOTE(review): patch shape [384, 384, 3] and class count 2
        # are hard-coded here instead of using `num_classes` -- confirm intent.
        self.get_patches_fn = lambda filename: tf.py_func(
            self.extract_patch, [filename, [384, 384, 3], 2],
            [tf.float32, tf.float32])

        if mode == 'training':
            data = data.map(self.get_patches_fn, num_parallel_calls=8)

        elif mode == 'inference':
            # NOTE(review): `_parse_function_inference` is not defined in this
            # file, so selecting mode='inference' raises AttributeError at
            # this line -- it must be supplied elsewhere or implemented.
            data = data.map(self._parse_function_inference, num_parallel_calls=8)

        else:
            raise ValueError("Invalid mode '%s'." % (mode))

        # shuffle the first `buffer_size` elements of the dataset
        if shuffle:
            data = data.shuffle(buffer_size=buffer_size)

        # create a new dataset with batches of images
        data = data.batch(batch_size)

        self.data = data

    def _read_txt_file(self):
        """Read the content of the text file and store it into a list.

        Each kept entry is one line without its trailing newline. The original
        `row[:-1]` silently chopped the final character of a last line that
        lacked a newline; blank lines are now skipped as well.
        """
        with open(self.txt_file, 'r') as f:
            self.img_paths = [line.rstrip('\n') for line in f if line.strip()]

    def _shuffle_lists(self):
        """Shuffle the list of image paths in place (uniform permutation)."""
        permutation = np.random.permutation(self.data_size)
        self.img_paths = [self.img_paths[i] for i in permutation]

    def extract_patch(self, filename, patch_size, num_class, num_patches=1):
        """Load one volume pair and extract random axial patches.

        Args:
            filename: comma-separated "image_path,label_path" string
                (bytes when invoked through ``tf.py_func``).
            patch_size: (H, W, D) patch shape; forwarded to
                ``random_patch_center_z``.
            num_class: number of classes for the one-hot mask.
            num_patches: patches to draw; only the first is returned, matching
                the original behavior.

        Returns:
            Tuple ``(image_patch, mask_patch)``: the image patch is three
            adjacent slices (H, W, 3) and the mask is one-hot
            (H, W, num_class), both float32.
        """
        image, mask = self.parse_fn(filename)  # get the image and its mask
        image_patches = []
        mask_patches = []

        for _ in range(num_patches):
            # centre slice index of the current patch
            z = self.random_patch_center_z(mask, patch_size=patch_size)
            image_patches.append(image[:, :, z - 1:z + 2])
            mask_patches.append(mask[:, :, z])

        image_patches = np.stack(image_patches)  # (num_patches, H, W, 3)
        mask_patches = np.stack(mask_patches)    # (num_patches, H, W)

        # one-hot: (num_patches, H, W, num_class)
        mask_patches = self._label_decomp(mask_patches, num_cls=num_class)
        return (image_patches[0, ...].astype(np.float32),
                mask_patches[0, ...].astype(np.float32))

    def random_patch_center_z(self, mask, patch_size):
        """Pick a random axial slice index inside the foreground of `mask`.

        The index is bounded so that slices z-1 and z+1 exist and z lies
        within the nonzero (foreground) z-extent of the mask.

        Fixes vs. the original: the first, immediately-discarded random draw
        is removed -- its `else` branch referenced the undefined name
        `patchsize` and raised NameError whenever taken. A guard is added for
        a degenerate foreground extent, where `np.random.randint` would raise
        ValueError on an empty range.

        Args:
            mask: 3-D array with foreground voxels > 0; must contain at least
                one nonzero voxel.
            patch_size: kept for interface compatibility; the effective bound
                computation (matching the original's final statement) does not
                use it.
        """
        _, _, limZ = np.where(mask > 0)
        low = max(1, np.min(limZ))
        high = min(np.max(limZ), mask.shape[2] - 2)
        if high <= low:
            # degenerate extent: only one valid centre
            return low
        return np.random.randint(low=low, high=high)

    def parse_fn(self, data_path):
        """Load and normalize one image/label volume pair.

        Args:
            data_path: "image_path,label_path" string; ``tf.py_func`` hands
                this over as ``bytes`` under Python 3, so both are accepted.

        Returns:
            ``(image, mask)`` with axes reordered from (Z, H, W) -- the
            SimpleITK array layout -- to (H, W, Z). The image is z-score
            normalized over the whole volume; labels 2 are merged into 1.
        """
        # tf.py_func passes filenames as bytes under Python 3
        if isinstance(data_path, bytes):
            data_path = data_path.decode('utf-8')
        path = data_path.split(",")
        image_path = path[0]
        label_path = path[1]

        itk_image = sitk.ReadImage(image_path)
        itk_mask = sitk.ReadImage(label_path)

        image = sitk.GetArrayFromImage(itk_image)
        mask = sitk.GetArrayFromImage(itk_mask)

        # Normalize per volume. NOTE(review): the original comment claimed the
        # statistics were restricted to the brain region, but its mask was all
        # ones -- so this is a plain whole-volume mean/std, written as such.
        mean = image.mean()
        std = image.std()
        image = (image - mean) / std

        # merge label 2 into label 1
        mask[mask == 2] = 1

        # (Z, H, W) -> (H, W, Z)
        return image.transpose([1, 2, 0]), mask.transpose([1, 2, 0])

    def _label_decomp(self, label_vol, num_cls):
        """Decompose integer labels into one-hot along a new last axis.

        Numpy version of ``tf.one_hot``: input is (batch, H, W) with label
        values 0..num_cls-1; output is (batch, H, W, num_cls) float64,
        matching the original `np.zeros`-based implementation.

        Fix vs. original: `xrange` (Python 2 only) replaced by `range`.
        """
        one_hot = [(label_vol == i).astype(np.float64) for i in range(num_cls)]
        return np.stack(one_hot, axis=-1)
+