Diff of /data_generator.py [000000] .. [7b5b9f]

Switch to unified view

a b/data_generator.py
1
import tensorflow as tf
2
import numpy as np
3
import os
4
# from matplotlib import pyplot as plt
5
from tensorflow.python.framework import dtypes
6
from tensorflow.python.framework.ops import convert_to_tensor
7
import skimage as sk
8
from skimage import transform
9
import SimpleITK as sitk
10
11
IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32)
12
13
14
class ImageDataGenerator(object):
15
16
    def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, buffer_size=5):
17
18
        """Create a new ImageDataGenerator.
19
        Receives a path string to a text file, where each line has a path string to an image and
20
        separated by a space, then with an integer referring to the class number.
21
22
        Args:
23
            txt_file: path to the text file.
24
            mode: either 'training' or 'validation'. Depending on this value, different parsing functions will be used.
25
            batch_size: number of images per batch.
26
            num_classes: number of classes in the dataset.
27
            shuffle: wether or not to shuffle the data in the dataset and the initial file list.
28
            buffer_size: number of images used as buffer for TensorFlows shuffling of the dataset.
29
30
        Raises:
31
            ValueError: If an invalid mode is passed.
32
        """
33
34
        self.txt_file = txt_file
35
        self.num_classes = num_classes
36
37
        # retrieve the data from the text file
38
        self._read_txt_file()
39
40
        # number of samples in the dataset
41
        self.data_size = len(self.img_paths)
42
43
        # initial shuffling of the file and label lists together
44
        if shuffle:
45
            self._shuffle_lists()
46
47
        # convert lists to TF tensor
48
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)
49
50
        # create dataset
51
        data = tf.data.Dataset.from_tensor_slices((self.img_paths))
52
53
        # repeat indefinitely (train.py will count the epochs)
54
        data = data.repeat()
55
56
        # distinguish between train/infer. when calling the parsing functions
57
        self.get_patches_fn = lambda filename: tf.py_func(self.extract_patch, [filename, [384,384,3], 2], [tf.float32, tf.float32])
58
59
        if mode == 'training':
60
            data = data.map(self.get_patches_fn, num_parallel_calls=8)
61
62
        elif mode == 'inference':
63
            data = data.map(self._parse_function_inference, num_parallel_calls=8)
64
65
        else:
66
            raise ValueError("Invalid mode '%s'." % (mode))
67
68
        # shuffle the first `buffer_size` elements of the dataset
69
        if shuffle:
70
            data = data.shuffle(buffer_size=buffer_size)
71
72
        # create a new dataset with batches of images
73
        data = data.batch(batch_size)
74
75
        self.data = data
76
77
    def _read_txt_file(self):
78
        """Read the content of the text file and store it into lists."""
79
        with open(self.txt_file, 'r') as f:
80
            rows = f.readlines()
81
            self.img_paths = [row[:-1] for row in rows]
82
83
    def _shuffle_lists(self):
84
        """Conjoined shuffling of the list of paths and labels."""
85
        path = self.img_paths
86
        permutation = np.random.permutation(self.data_size)
87
        self.img_paths = []
88
        for i in permutation:
89
            self.img_paths.append(path[i])
90
91
    def extract_patch(self, filename, patch_size, num_class, num_patches=1):
92
        """Input parser for samples of the training set."""
93
        # convert label number into one-hot-encoding
94
95
        image, mask = self.parse_fn(filename) # get the image and its mask
96
        image_patches = []
97
        mask_patches = []
98
        num_patches_now = 0
99
100
        while num_patches_now < num_patches:
101
            # z = np.random.randint(1, mask.shape[2]-1)
102
            z = self.random_patch_center_z(mask, patch_size=patch_size) # define the centre of current patch
103
            image_patch = image[:, :, z-1:z+2]
104
            mask_patch  =  mask[:, :, z]
105
            
106
            image_patches.append(image_patch)
107
            mask_patches.append(mask_patch)
108
            num_patches_now += 1
109
        image_patches = np.stack(image_patches) # make into 4D (batch_size, patch_size[0], patch_size[1], patch_size[2])
110
        mask_patches = np.stack(mask_patches) # make into 4D (batch_size, patch_size[0], patch_size[1], patch_size[2])
111
112
        mask_patches = self._label_decomp(mask_patches, num_cls=num_class) # make into 5D (batch_size, patch_size[0], patch_size[1], patch_size[2], num_classes)
113
        #print image_patches.shape
114
        return image_patches[0,...].astype(np.float32), mask_patches[0,...].astype(np.float32)
115
116
    def random_patch_center_z(self, mask, patch_size):
117
        # bounded within the brain mask region
118
        limX, limY, limZ = np.where(mask>0)
119
        if (np.min(limZ) + patch_size[2] // 2 + 1) < (np.max(limZ) - patch_size[2] // 2):
120
            z = np.random.randint(low = np.min(limZ) + patch_size[2] // 2 + 1, high = np.max(limZ) - patch_size[2] // 2)
121
        else:
122
            z = np.random.randint(low = patchsize[2]//2, high = mask.shape[2] - patchsize[2]//2)
123
124
        limX, limY, limZ = np.where(mask>0)
125
126
        z = np.random.randint(low = max(1, np.min(limZ)), high = min(np.max(limZ), mask.shape[2] - 2))
127
        # z = np.random.randint(low = max(1, np.min(limZ)), high = min(np.max(limZ), mask.shape[2] - 2))
128
129
        return z
130
131
    def parse_fn(self, data_path):
132
        '''
133
        :param image_path: path to a folder of a patient
134
        :return: normalized entire image with its corresponding label
135
        In an image, the air region is 0, so we only calculate the mean and std within the brain area
136
        For any image-level normalization, do it here
137
        '''
138
        path = data_path.split(",")
139
        image_path = path[0]
140
        label_path = path[1]
141
        #itk_image = zoom2shape(image_path, [512,512])#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz'))
142
        #itk_mask = zoom2shape(label_path, [512,512], label=True)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz'))
143
        itk_image = sitk.ReadImage(image_path)#os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz'))
144
        itk_mask = sitk.ReadImage(label_path)#os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz'))
145
        # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz'))
146
147
        image = sitk.GetArrayFromImage(itk_image)
148
        mask = sitk.GetArrayFromImage(itk_mask)
149
        #image[image >= 1000] = 1000
150
        binary_mask = np.ones(mask.shape)
151
        mean = np.sum(image * binary_mask) / np.sum(binary_mask)
152
        std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask))
153
        image = (image - mean) / std  # normalize per image, using statistics within the brain, but apply to whole image
154
155
        mask[mask==2] = 1
156
157
        return image.transpose([1,2,0]), mask.transpose([1,2,0]) # transpose the orientation of the
158
159
160
    def _label_decomp(self, label_vol, num_cls):
161
        """
162
        decompose label for softmax classifier
163
        original labels are batchsize * W * H * 1, with label values 0,1,2,3...
164
        this function decompse it to one hot, e.g.: 0,0,0,1,0,0 in channel dimension
165
        numpy version of tf.one_hot
166
        """
167
        one_hot = []
168
        for i in xrange(num_cls):
169
            _vol = np.zeros(label_vol.shape)
170
            _vol[label_vol == i] = 1
171
            one_hot.append(_vol)
172
173
        return np.stack(one_hot, axis=-1)
174
    # def augment(self, x):
175
    #     # add more types of augmentations here
176
    #     augmentations = [self.flip]
177
    #     for f in augmentations:
178
    #         x = tf.cond(tf.random_uniform([], 0, 1) < 0.25, lambda: f(x), lambda: x)
179
            
180
    #     return x
181
182
    # def flip(self, x):
183
    #     """Flip augmentation
184
    #     Args:
185
    #         x: Image to flip
186
    #     Returns:
187
    #         Augmented image
188
    #     """
189
    #     x = tf.image.random_flip_left_right(x)
190
    #     # x = tf.image.random_flip_up_down(x)
191
192
    #     return x
193