|
a |
|
b/data_generator.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import numpy as np |
|
|
3 |
import os |
|
|
4 |
# from matplotlib import pyplot as plt |
|
|
5 |
from tensorflow.python.framework import dtypes |
|
|
6 |
from tensorflow.python.framework.ops import convert_to_tensor |
|
|
7 |
import skimage as sk |
|
|
8 |
from skimage import transform |
|
|
9 |
import SimpleITK as sitk |
|
|
10 |
|
|
|
11 |
# Three-element float32 constant; judging by the name these are presumably the
# per-channel ImageNet pixel means (TODO confirm channel order). It is not
# referenced anywhere in the visible part of this module.
IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32)
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class ImageDataGenerator(object):
    """Build a batched ``tf.data`` pipeline of 2.5D slabs cut from 3D volumes.

    Each line of the input text file is ``<image_path>,<label_path>``. For
    every line a random axial position is drawn and a 3-slice image slab plus
    the one-hot encoded label map of the centre slice are produced.
    """

    def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, buffer_size=5):
        """Create a new ImageDataGenerator.

        Args:
            txt_file: path to a text file with one ``image_path,label_path``
                pair per line.
            mode: either 'training' or 'inference'. Selects the function that
                is mapped over the file list.
            batch_size: number of samples per batch.
            num_classes: number of classes in the dataset. Stored on the
                instance; the training map currently hard-codes 2 classes.
            shuffle: whether to shuffle the initial file list and the dataset.
            buffer_size: number of elements in TensorFlow's shuffle buffer.

        Raises:
            ValueError: If an invalid mode is passed.
        """
        self.txt_file = txt_file
        self.num_classes = num_classes

        # Load the list of "image,label" path pairs.
        self._read_txt_file()

        # Number of samples in the dataset.
        self.data_size = len(self.img_paths)

        # Initial shuffling of the file list.
        if shuffle:
            self._shuffle_lists()

        # Convert the Python list to a string tensor and wrap it in a dataset.
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)
        data = tf.data.Dataset.from_tensor_slices((self.img_paths))

        # Repeat indefinitely; the training script is expected to count epochs.
        data = data.repeat()

        def _get_patches(filename):
            # Patch extraction is plain NumPy/SimpleITK code, so it has to be
            # wrapped in py_func to run inside the TF input pipeline.
            return tf.py_func(
                self.extract_patch,
                [filename, [384, 384, 3], 2],
                [tf.float32, tf.float32],
            )

        self.get_patches_fn = _get_patches

        if mode == 'training':
            data = data.map(self.get_patches_fn, num_parallel_calls=8)
        elif mode == 'inference':
            # NOTE(review): `_parse_function_inference` is not defined in the
            # visible part of this class; unless a subclass provides it this
            # branch raises AttributeError.
            data = data.map(self._parse_function_inference, num_parallel_calls=8)
        else:
            raise ValueError("Invalid mode '%s'." % (mode))

        # Shuffle within a sliding buffer of `buffer_size` elements.
        if shuffle:
            data = data.shuffle(buffer_size=buffer_size)

        # Group consecutive samples into batches.
        data = data.batch(batch_size)

        self.data = data

    def _read_txt_file(self):
        """Read ``self.txt_file`` and store its non-empty lines in ``self.img_paths``.

        Only the line terminator is stripped, so a final line without a
        trailing newline keeps its last character. Blank lines are skipped.
        """
        with open(self.txt_file, 'r') as f:
            self.img_paths = [line.rstrip('\r\n') for line in f if line.strip()]

    def _shuffle_lists(self):
        """Randomly permute ``self.img_paths`` using NumPy's global RNG."""
        paths = self.img_paths
        permutation = np.random.permutation(len(paths))
        self.img_paths = [paths[i] for i in permutation]

    def extract_patch(self, filename, patch_size, num_class, num_patches=1):
        """Cut a random 3-slice slab and its one-hot centre-slice label map.

        Args:
            filename: ``image_path,label_path`` entry (str or bytes).
            patch_size: nominal patch size, forwarded to
                ``random_patch_center_z``.
            num_class: number of channels in the one-hot label map.
            num_patches: number of slabs to draw; only the first is returned.

        Returns:
            Tuple ``(image, mask)`` of float32 arrays: ``image`` is the slab
            ``[:, :, z-1:z+2]`` and ``mask`` is slice ``z`` one-hot encoded
            along a new last axis.
        """
        image, mask = self.parse_fn(filename)  # full volume and its label map
        image_patches = []
        mask_patches = []

        for _ in range(num_patches):
            # Centre slice of the slab; always an interior slice so that
            # z - 1 and z + 1 exist.
            z = self.random_patch_center_z(mask, patch_size=patch_size)
            image_patches.append(image[:, :, z - 1:z + 2])
            mask_patches.append(mask[:, :, z])

        image_patches = np.stack(image_patches)  # (num_patches, H, W, 3)
        mask_patches = np.stack(mask_patches)    # (num_patches, H, W)

        # (num_patches, H, W, num_class)
        mask_patches = self._label_decomp(mask_patches, num_cls=num_class)

        return image_patches[0, ...].astype(np.float32), mask_patches[0, ...].astype(np.float32)

    def random_patch_center_z(self, mask, patch_size):
        """Draw the centre slice index for a 3-slice slab.

        The index is drawn uniformly from ``[low, high)`` with
        ``low = max(1, zmin)`` and ``high = min(zmax, depth - 2)``, where
        ``zmin``/``zmax`` bound the slices holding labels > 0. If that range
        is empty, the midpoint of the labelled extent clamped to the interior
        slices is returned. If the mask has no labelled voxel, a random
        interior slice is returned.

        Args:
            mask: 3D label volume; the last axis is the slice axis.
            patch_size: kept for interface compatibility; the one-slice margin
                matches the fixed 3-slice slab taken by ``extract_patch``.

        Returns:
            int: slice index ``z`` with ``1 <= z <= depth - 2``.

        Raises:
            ValueError: If the volume has fewer than 3 slices.
        """
        depth = mask.shape[2]
        if depth < 3:
            raise ValueError(
                "Need at least 3 slices to extract a slab, got %d." % depth)

        labelled_z = np.nonzero(mask > 0)[2]
        if labelled_z.size == 0:
            # Nothing labelled: any interior slice is as good as another.
            return int(np.random.randint(low=1, high=depth - 1))

        z_min = int(labelled_z.min())
        z_max = int(labelled_z.max())
        low = max(1, z_min)
        high = min(z_max, depth - 2)
        if low >= high:
            # Labelled extent too thin (or touching the border) for randint.
            return int(np.clip((z_min + z_max) // 2, 1, depth - 2))

        return int(np.random.randint(low=low, high=high))

    def parse_fn(self, data_path):
        '''
        Load one image/label pair and normalise the image.

        :param data_path: "image_path,label_path" (str, or bytes as delivered
            by tf.py_func under Python 3)
        :return: (image, mask) with axes reordered by transpose([1, 2, 0]),
            which moves the first array axis (the slice axis in SimpleITK's
            array layout) to the end

        The image is z-scored with the mean and standard deviation of the
        whole volume. Label value 2 is merged into label 1.
        Any other image-level normalisation belongs here.
        '''
        if isinstance(data_path, bytes):
            data_path = data_path.decode('utf-8')
        image_path, label_path = data_path.split(",")[:2]

        itk_image = sitk.ReadImage(image_path)
        itk_mask = sitk.ReadImage(label_path)

        image = sitk.GetArrayFromImage(itk_image).astype(np.float64)
        mask = sitk.GetArrayFromImage(itk_mask)

        mean = image.mean()
        std = image.std()
        image = image - mean
        if std > 0:
            # A constant volume has std == 0; leave it centred instead of
            # filling it with NaN/inf.
            image = image / std

        mask[mask == 2] = 1

        return image.transpose([1, 2, 0]), mask.transpose([1, 2, 0])

    def _label_decomp(self, label_vol, num_cls):
        """One-hot encode an integer label array (NumPy version of tf.one_hot).

        Args:
            label_vol: array of label values 0, 1, 2, ...
            num_cls: number of classes, i.e. size of the new last axis.

        Returns:
            float64 array of shape ``label_vol.shape + (num_cls,)`` where
            channel ``i`` is 1 wherever ``label_vol == i`` and 0 elsewhere.
        """
        labels = np.asarray(label_vol)
        classes = np.arange(num_cls)
        return np.equal(labels[..., np.newaxis], classes).astype(np.float64)
|
|
174 |
# def augment(self, x): |
|
|
175 |
# # add more types of augmentations here |
|
|
176 |
# augmentations = [self.flip] |
|
|
177 |
# for f in augmentations: |
|
|
178 |
# x = tf.cond(tf.random_uniform([], 0, 1) < 0.25, lambda: f(x), lambda: x) |
|
|
179 |
|
|
|
180 |
# return x |
|
|
181 |
|
|
|
182 |
# def flip(self, x): |
|
|
183 |
# """Flip augmentation |
|
|
184 |
# Args: |
|
|
185 |
# x: Image to flip |
|
|
186 |
# Returns: |
|
|
187 |
# Augmented image |
|
|
188 |
# """ |
|
|
189 |
# x = tf.image.random_flip_left_right(x) |
|
|
190 |
# # x = tf.image.random_flip_up_down(x) |
|
|
191 |
|
|
|
192 |
# return x |
|
|
193 |
|