"""Functions for reading MSISBI2015 NIfTI data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import glob
import math
import os.path
import pickle

import matplotlib.pyplot
import numpy
import numpy as np  # the code below uses both the `numpy` and `np` names
from imageio import imwrite
from scipy.ndimage import zoom
from six.moves import xrange
from skimage.measure import label, regionprops

from utils.NII import *
from utils.image_utils import crop, crop_center
from utils.tfrecord_utils import *

matplotlib.pyplot.ion()

class MSISBI2015(object):
    PROTOCOL_MAPPINGS = {'FLAIR': ['flair'], 'MPRAGE': ['mprage'], 'PD': ['pd'], 'T2': ['t2']}
    SET_TYPES = ['TRAIN', 'VAL', 'TEST']

    class Options(object):
        def __init__(self):
            self.dir = os.path.dirname(os.path.realpath(__file__))
            self.numSamples = -1
            self.partition = {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1}
            self.useCrops = False
            self.cropType = 'random'  # random, center or lesions
            self.numRandomCropsPerSlice = 5
            self.onlyPatchesWithLesions = False
            self.rotations = 0
            self.cropWidth = 128
            self.cropHeight = 128
            self.cache = False
            self.sliceResolution = None  # format: HxW
            self.addInstanceNoise = False  # Affects only the batch sampling. If True, a tiny bit of noise is added to every batch
            self.filterProtocols = []  # e.g. ['FLAIR', 'T2']; an empty list keeps all protocols
            self.filterType = "train"  # train or test
            self.axis = 'axial'  # saggital, coronal or axial
            self.debug = False
            self.normalizationMethod = 'standardization'
            self.sliceStart = 0
            self.sliceEnd = 155
            self.format = "raw"  # raw or aligned; if aligned, *.aligned.nii.gz files are crawled and loaded
            self.skullStripping = True
            self.viewMapping = {'saggital': 2, 'coronal': 1, 'axial': 0}

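    # Builds the dataset: either restores a previously cached split and TFRecord
    # (if options.cache is set and the files exist) or crawls the patient folders,
    # draws a random train/val/test split and extracts the slices into numpy arrays.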
    def __init__(self, options=Options()):
        self.options = options

        if options.cache and os.path.isfile(self.pckl_name()):
            f = open(self.pckl_name(), 'rb')
            tmp = pickle.load(f)
            f.close()
            self._epochs_completed = tmp._epochs_completed
            self._index_in_epoch = tmp._index_in_epoch
            self.patientsSplit = tmp.patientsSplit
            self.patients = tmp.patients
            self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name())
            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
        else:
            # Collect all patients
            self.patients = self._get_patients()
            self.patientsSplit = {}

            if not os.path.isfile(self.split_name()):
                _numPatients = len(self.patients)
                _ridx = numpy.random.permutation(_numPatients)

                _already_taken = 0
                for split in self.options.partition.keys():
                    if self.options.partition[split] <= 1.0:
                        num_patients_for_current_split = math.floor(self.options.partition[split] * _numPatients)
                    else:
                        num_patients_for_current_split = self.options.partition[split]

                    if num_patients_for_current_split > (_numPatients - _already_taken):
                        num_patients_for_current_split = _numPatients - _already_taken

                    self.patientsSplit[split] = _ridx[_already_taken:_already_taken + num_patients_for_current_split]
                    _already_taken += num_patients_for_current_split

                f = open(self.split_name(), 'wb')
                pickle.dump(self.patientsSplit, f)
                f.close()
            else:
                f = open(self.split_name(), 'rb')
                self.patientsSplit = pickle.load(f)
                f.close()

            self._create_numpy_arrays()

            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}

            if self.options.cache:
                write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name())
                tmp = copy.copy(self)
                tmp._images = None
                tmp._labels = None
                tmp._sets = None
                f = open(self.pckl_name(), 'wb')
                pickle.dump(tmp, f)
                f.close()

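    # Converts the per-patient volumes into flat numpy arrays of slices:
    # self._images, self._labels and a per-slice split index self._sets.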
    def _create_numpy_arrays(self):
        # Iterate over all patients and extract slices
        _images = []
        _labels = []
        _sets = []
        for p, patient in enumerate(self.patients):
            if p in self.patientsSplit['TRAIN']:
                _set_of_current_patient = MSISBI2015.SET_TYPES.index('TRAIN')
            elif p in self.patientsSplit['VAL']:
                _set_of_current_patient = MSISBI2015.SET_TYPES.index('VAL')
            elif p in self.patientsSplit['TEST']:
                _set_of_current_patient = MSISBI2015.SET_TYPES.index('TEST')

            for n, nii_filename in enumerate(patient['filtered_files']):
                _images_tmp, _labels_tmp = self.gather_data(patient, nii_filename)
                _images += _images_tmp
                _labels += _labels_tmp
                _sets += [_set_of_current_patient] * len(_images_tmp)

        self._images = numpy.array(_images).astype(numpy.float32)
        self._labels = numpy.array(_labels).astype(numpy.float32)
        if self._images.ndim < 4:
            self._images = numpy.expand_dims(self._images, 3)
        self._sets = numpy.array(_sets).astype(numpy.int32)

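    # Extracts 2D samples (whole slices or crops, depending on the options) from a
    # single volume of one patient and returns two lists: the image slices and the
    # corresponding ground-truth segmentations.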
    def gather_data(self, patient, nii_filename):
        _images = []
        _labels = []

        nii, nii_seg, nii_skullmap = self.load_volume_and_groundtruth(nii_filename, patient)

        # Iterate over the configured slice range along the chosen axis and collect the slices
        for s in xrange(self.options.sliceStart, min(self.options.sliceEnd, nii.num_slices_along_axis(self.options.axis))):
            if 0 < self.options.numSamples < len(_images):
                break

            slice_data = nii.get_slice(s, self.options.axis)
            slice_seg = nii_seg.get_slice(s, self.options.axis)

            # Skip the slice if it is "empty"
            if numpy.percentile(slice_data, 90) < 0.2:
                continue

            if self.options.sliceResolution is not None:
                # Pad with zeros, if the slice is smaller than the target resolution
                before_y = after_y = before_x = after_x = 0
                if slice_data.shape[0] < self.options.sliceResolution[0]:
                    before_y = math.floor((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
                    after_y = math.ceil((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
                if slice_data.shape[1] < self.options.sliceResolution[1]:
                    before_x = math.floor((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
                    after_x = math.ceil((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
                if slice_data.shape[0] < self.options.sliceResolution[0] or slice_data.shape[1] < self.options.sliceResolution[1]:
                    slice_data = np.pad(slice_data, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
                    slice_seg = np.pad(slice_seg, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
                slice_data = zoom(slice_data, float(self.options.sliceResolution[0]) / float(slice_data.shape[0]))
                slice_seg = zoom(slice_seg, float(self.options.sliceResolution[0]) / float(slice_seg.shape[0]), mode="nearest")
                slice_seg[slice_seg < 0.9] = 0.0
                slice_seg[slice_seg >= 0.9] = 1.0
                # assert numpy.max(slice_data) <= 1.0, "Resized slice range is outside [0; 1]!"

            # Either collect crops
            if self.options.useCrops:
                if self.options.cropType == 'random':
                    rx = numpy.random.randint(0, high=(slice_data.shape[1] - self.options.cropWidth),
                                              size=self.options.numRandomCropsPerSlice)
                    ry = numpy.random.randint(0, high=(slice_data.shape[0] - self.options.cropHeight),
                                              size=self.options.numRandomCropsPerSlice)
                    for r in range(self.options.numRandomCropsPerSlice):
                        _images.append(crop(slice_data, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
                        _labels.append(crop(slice_seg, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
                elif self.options.cropType == 'center':
                    slice_data_cropped = crop_center(slice_data, self.options.cropWidth, self.options.cropHeight)
                    slice_seg_cropped = crop_center(slice_seg, self.options.cropWidth, self.options.cropHeight)
                    _images.append(slice_data_cropped)
                    _labels.append(slice_seg_cropped)
                elif self.options.cropType == 'lesions':
                    cc_slice = label(slice_seg)
                    props = regionprops(cc_slice)
                    if len(props) > 0:
                        for prop in props:
                            cx = prop['centroid'][1]
                            cy = prop['centroid'][0]
                            # Clamp the crop center such that the crop stays inside the slice
                            if cy < self.options.cropHeight // 2:
                                cy = self.options.cropHeight // 2
                            if cy > (slice_data.shape[0] - (self.options.cropHeight // 2)):
                                cy = (slice_data.shape[0] - (self.options.cropHeight // 2))
                            if cx < self.options.cropWidth // 2:
                                cx = self.options.cropWidth // 2
                            if cx > (slice_data.shape[1] - (self.options.cropWidth // 2)):
                                cx = (slice_data.shape[1] - (self.options.cropWidth // 2))
                            image_crop = crop(slice_data, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
                                              self.options.cropHeight, self.options.cropWidth)
                            seg_crop = crop(slice_seg, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
                                            self.options.cropHeight, self.options.cropWidth)
                            if image_crop.shape[0] != self.options.cropHeight or image_crop.shape[1] != self.options.cropWidth:
                                continue
                            _images.append(image_crop)
                            _labels.append(seg_crop)

            # Or whole slices
            else:
                _images.append(slice_data)
                _labels.append(slice_seg)

        return _images, _labels

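    # Loads one volume together with its ground-truth segmentation, optionally
    # applies the skullmap and normalizes the intensities in place.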
    def load_volume_and_groundtruth(self, nii_filename, patient):
        # Load the NIfTI volume and its ground-truth segmentation
        try:
            nii = NII(nii_filename)
            nii_groundtruth = NII(patient['groundtruth'])

            nii.denoise()
            nii.set_view_mapping(self.options.viewMapping)
        except Exception:
            print('MSISBI2015: Failed to open file ' + nii_filename)
            raise

        # Make sure the ground-truth is binary and the volume doesn't contain NaNs
        nii.data[np.isnan(nii.data)] = 0.0
        nii_groundtruth.data[nii_groundtruth.data < 0.9] = 0.0
        nii_groundtruth.data[nii_groundtruth.data >= 0.9] = 1.0

        # Do skull-stripping, if desired
        nii_skullmap = None
        if self.options.skullStripping:
            try:
                nii_skullmap = NII(patient['skullmap'])
                nii_skullmap.set_view_mapping(self.options.viewMapping)
                nii.apply_skullmap(nii_skullmap)
            except Exception:
                print('MSISBI2015: Failed to open file ' + patient['skullmap'] + ', skipping skull removal')

        # In-place normalize the loaded volume
        nii.normalize(method=self.options.normalizationMethod, lowerpercentile=0, upperpercentile=99.8)
        # nii_skullmap.data = nii_skullmap.data > 0.0

        return nii, nii_groundtruth, nii_skullmap

    # Hidden helper function, not supposed to be called from outside!
    def _get_patients(self):
        return MSISBI2015.get_patients(self.options)

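    # Crawls the ISBI 2015 training folders (training01 to training05) and builds one
    # dictionary per patient holding the paths to the requested protocols, the
    # ground-truth mask and the skullmap. For options.format == "raw" the files are
    # expected to follow the "<name>_<protocol>_pp.nii" naming scheme; for "aligned",
    # "<name>_<protocol>.aligned.nii.gz".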
    @staticmethod
    def get_patients(options):
        folders = ["training01", "training02", "training03", "training04", "training05"]

        # Iterate over all folders and collect all patients
        patients = []
        for f, folder in enumerate(folders):
            # Get all files that can be used for training and validation
            _patients = glob.glob(os.path.join(options.dir, folder, "preprocessed", folder + "_*_flair_pp.nii"))
            for p, pname in enumerate(_patients):
                patient = {}
                _tmp = os.path.normpath(pname).split(os.path.sep)
                patient['name'] = _tmp[-1].replace("_flair_pp.nii", "")
                patient['fullpath'] = os.path.join(options.dir, folder, "preprocessed")

                patient["filtered_files"] = []
                for protocol, protocol_array in MSISBI2015.PROTOCOL_MAPPINGS.items():
                    if len(options.filterProtocols) > 0 and protocol not in options.filterProtocols:
                        continue
                    if options.format == "raw":
                        patient[protocol] = os.path.join(options.dir, folder, "preprocessed",
                                                         patient['name'] + '_' + protocol_array[0] + '_pp.nii')
                    elif options.format == "aligned":
                        patient[protocol] = os.path.join(options.dir, folder, "preprocessed",
                                                         patient['name'] + '_' + protocol_array[0] + '.aligned.nii.gz')
                    patient["filtered_files"] += [patient[protocol]]

                if options.format == "raw":
                    patient['groundtruth'] = os.path.join(options.dir, folder, "masks", patient['name'] + "_mask1.nii")
                    patient['skullmap'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_skullmap.nii.gz")
                elif options.format == "aligned":
                    patient['groundtruth'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_mask1.aligned.nii.gz")
                    patient['skullmap'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_skullmap.aligned.nii.gz")

                # Append to the list of all patients
                patients.append(patient)

        return patients

    # Returns the indices of patients which belong to either TRAIN, VAL or TEST
    def get_patient_idx(self, split='TRAIN'):
        return self.patientsSplit[split]

    def get_patient_split(self):
        return self.patientsSplit

    @property
    def images(self):
        return self._images

    def get_images(self, set=None):
        _setIdx = MSISBI2015.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return self._images[images_in_set]

    def get_image(self, i):
        return self._images[i, :, :, :]

    def get_label(self, i):
        return self._labels[i, :, :, :]

    def get_patient(self, i):
        return self.patients[i]

    @property
    def labels(self):
        return self._labels

    @property
    def sets(self):
        return self._sets

    @property
    def meta(self):
        return self._meta

    @property
    def num_examples(self):
        return self._images.shape[0]

    @property
    def width(self):
        return self._images.shape[2]

    @property
    def height(self):
        return self._images.shape[1]

    @property
    def num_channels(self):
        return self._images.shape[3]

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def name(self):
        _name = "MSISBI2015"
        if self.options.numSamples > 0:
            _name += '_n{}'.format(self.options.numSamples)
        _name += "_p{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL'])
        if self.options.useCrops:
            _name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight)
            if self.options.cropType == "random":
                _name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice)
        if self.options.sliceResolution is not None:
            _name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1])
        _name += "_{}".format(self.options.format)
        return _name

    def split_name(self):
        return os.path.join(self.dir(), 'split-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL']))

    def pckl_name(self):
        return os.path.join(self.dir(), self.name() + ".pckl")

    def tfrecord_name(self):
        return os.path.join(self.dir(), self.name() + ".tfrecord")

    def dir(self):
        return self.options.dir

    def export_slices(self, dir):
        for i in range(self.num_examples):
            imwrite(os.path.join(dir, '{}.png'.format(i)), np.squeeze(self.get_image(i) * 255).astype('uint8'))

    def visualize(self, pause=1):
        f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2)
        images_tmp, labels_tmp, _ = self.next_batch(10)
        for i in range(images_tmp.shape[0]):
            img = numpy.squeeze(images_tmp[i])
            lbl = numpy.squeeze(labels_tmp[i])
            ax1.imshow(img)
            ax1.set_title('Patch')
            ax2.imshow(lbl)
            ax2.set_title('Groundtruth')
            matplotlib.pyplot.pause(pause)

    def num_batches(self, batchsize, set='TRAIN'):
        _setIdx = MSISBI2015.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return len(images_in_set) // batchsize

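    # Note: when shuffling, the arrays backing the requested split are permuted in
    # place (self._images, self._labels, self._sets), so the global sample order
    # changes across epochs.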
    def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=True):
        """Return the next `batch_size` examples from this data set."""
        _setIdx = MSISBI2015.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        samples_in_set = len(images_in_set)

        start = self._index_in_epoch[set]
        # Shuffle for the first epoch
        if self._epochs_completed[set] == 0 and start == 0 and shuffle:
            perm0 = numpy.arange(samples_in_set)
            numpy.random.shuffle(perm0)
            self._images[images_in_set] = self.images[images_in_set[perm0]]
            self._labels[images_in_set] = self.labels[images_in_set[perm0]]
            self._sets[images_in_set] = self.sets[images_in_set[perm0]]

        # Go to the next epoch
        if start + batch_size > samples_in_set:
            # Finished epoch
            self._epochs_completed[set] += 1

            # Get the remaining examples in this epoch
            rest_num_examples = samples_in_set - start
            images_rest_part = self._images[images_in_set[start:samples_in_set]]
            labels_rest_part = self._labels[images_in_set[start:samples_in_set]]

            # Shuffle the data
            if shuffle:
                perm = numpy.arange(samples_in_set)
                numpy.random.shuffle(perm)
                self._images[images_in_set] = self.images[images_in_set[perm]]
                self._labels[images_in_set] = self.labels[images_in_set[perm]]
                self._sets[images_in_set] = self.sets[images_in_set[perm]]

            # Start next epoch
            start = 0
            self._index_in_epoch[set] = batch_size - rest_num_examples
            end = self._index_in_epoch[set]
            images_new_part = self._images[images_in_set[start:end]]
            labels_new_part = self._labels[images_in_set[start:end]]

            images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0)
            labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
        else:
            self._index_in_epoch[set] += batch_size
            end = self._index_in_epoch[set]
            images_tmp = self._images[images_in_set[start:end]]
            labels_tmp = self._labels[images_in_set[start:end]]

        if self.options.addInstanceNoise:
            noise = numpy.random.normal(0, 0.01, images_tmp.shape)
            images_tmp += noise

        # Check the batch
        assert images_tmp.size, "The batch is empty!"
        assert labels_tmp.size, "The labels of the current batch are empty!"

        if return_brainmask:
            brainmasks = images_tmp > 0.05
        else:
            brainmasks = None

        return images_tmp, labels_tmp, brainmasks
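

# Minimal usage sketch (not part of the loader itself); it assumes the ISBI 2015
# data is laid out under Options().dir as expected by get_patients(), e.g.
# trainingXX/preprocessed/*_pp.nii and trainingXX/masks/*_mask1.nii.
if __name__ == "__main__":
    _options = MSISBI2015.Options()
    _options.sliceResolution = [128, 128]
    _dataset = MSISBI2015(_options)
    _images, _labels, _brainmasks = _dataset.next_batch(8, set='TRAIN')
    print(_images.shape, _labels.shape)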