|
a |
|
b/dataloaders/BRAINWEB.py |
|
|
1 |
"""Functions for reading BRAINWEB NII data.""" |
|
|
2 |
|
|
|
3 |
from __future__ import absolute_import |
|
|
4 |
from __future__ import division |
|
|
5 |
from __future__ import print_function |
|
|
6 |
|
|
|
7 |
import glob |
|
|
8 |
import math |
|
|
9 |
import os.path |
|
|
10 |
import pickle |
|
|
11 |
|
|
|
12 |
import cv2 |
|
|
13 |
import matplotlib.pyplot |
|
|
14 |
from imageio import imwrite |
|
|
15 |
from scipy.ndimage import rotate |
|
|
16 |
|
|
|
17 |
from utils.MINC import * |
|
|
18 |
from utils.image_utils import crop, crop_center |
|
|
19 |
from utils.tfrecord_utils import * |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
class BRAINWEB(object): |
|
|
23 |
FILTER_TYPES = ['NORMAL', 'MILDMS', 'MODERATEMS', 'SEVEREMS'] |
|
|
24 |
SET_TYPES = ['TRAIN', 'VAL', 'TEST'] |
|
|
25 |
LABELS = {'BACKGROUND': 0, 'CSF': 1, 'GM': 2, 'WM': 3, 'FAT': 4, 'MUSCLE': 5, 'SKIN': 6, 'SKULL': 7, 'GLIALMATTER': 8, 'CONNECTIVE': 9, 'LESION': 10} |
|
|
26 |
VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2} |
|
|
27 |
PROTOCOL_MAPPINGS = {'FLAIR': 'flair*', 'T2': 't2*'} |
|
|
28 |
|
|
|
29 |
class Options(object): |
|
|
30 |
def __init__(self): |
|
|
31 |
self.description = None |
|
|
32 |
self.dir = os.path.dirname(os.path.realpath(__file__)) |
|
|
33 |
self.folderNormal = 'normal' |
|
|
34 |
self.folderMildMS = os.path.join('lesions', 'mild') |
|
|
35 |
self.folderModerateMS = os.path.join('lesions', 'moderate') |
|
|
36 |
self.folderSevereMS = os.path.join('lesions', 'severe') |
|
|
37 |
self.folderGT = 'groundtruth' |
|
|
38 |
self.numSamples = -1 |
|
|
39 |
self.partition = {'TRAIN': 0.6, 'VAL': 0.15, 'TEST': 0.25} |
|
|
40 |
self.sliceStart = 20 |
|
|
41 |
self.sliceEnd = 140 |
|
|
42 |
self.useCrops = False |
|
|
43 |
self.cropType = 'random' # random or center |
|
|
44 |
self.numRandomCropsPerSlice = 5 |
|
|
45 |
self.rotations = [0] |
|
|
46 |
self.cropWidth = 128 |
|
|
47 |
self.cropHeight = 128 |
|
|
48 |
self.cache = False |
|
|
49 |
self.sliceResolution = None # format: HxW |
|
|
50 |
self.addInstanceNoise = False # Affects only the batch sampling. If True, a tiny bit of noise will be added to every batch |
|
|
51 |
self.filterProtocol = None # T2 or FLAIR only, not implemented for now |
|
|
52 |
self.filterType = None # MILDMS, MODERATEMS, SEVEREMS, NORMAL |
|
|
53 |
self.axis = 'axial' # saggital, coronal or axial |
|
|
54 |
self.debug = False |
|
|
55 |
self.normalizationMethod = 'standardization' |
|
|
56 |
self.skullRemoval = False |
|
|
57 |
self.backgroundRemoval = False |
|
|
58 |
|
|
|
59 |
def __init__(self, options=Options()): |
|
|
60 |
self.options = options |
|
|
61 |
|
|
|
62 |
if options.cache and os.path.isfile(self.pckl_name()): |
|
|
63 |
f = open(self.pckl_name(), 'rb') |
|
|
64 |
tmp = pickle.load(f) |
|
|
65 |
f.close() |
|
|
66 |
self._epochs_completed = tmp._epochs_completed |
|
|
67 |
self._index_in_epoch = tmp._index_in_epoch |
|
|
68 |
self.patients = self._get_patients() |
|
|
69 |
self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name()) |
|
|
70 |
|
|
|
71 |
f = open(self.split_name(), 'rb') |
|
|
72 |
self.patients_split = pickle.load(f) |
|
|
73 |
f.close() |
|
|
74 |
if not os.path.exists(self.split_name() + ".deprecated"): |
|
|
75 |
os.rename(self.split_name(), self.split_name() + ".deprecated") |
|
|
76 |
self._convert_patient_split() |
|
|
77 |
|
|
|
78 |
self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} |
|
|
79 |
self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} |
|
|
80 |
else: |
|
|
81 |
# Collect all patients |
|
|
82 |
self.patients = self._get_patients() |
|
|
83 |
self.patients_split = {} # Here we will later store the info whether a patient belongs to train, val or test |
|
|
84 |
|
|
|
85 |
# Determine Train, Val & Test set based on patients |
|
|
86 |
if not os.path.isfile(self.split_name()): |
|
|
87 |
_num_patients = len(self.patients) |
|
|
88 |
_ridx = numpy.random.permutation(_num_patients) |
|
|
89 |
|
|
|
90 |
_already_taken = 0 |
|
|
91 |
for split in self.options.partition.keys(): |
|
|
92 |
if 1.0 >= self.options.partition[split] > 0.0: |
|
|
93 |
num_patients_for_current_split = max(1, math.floor(self.options.partition[split] * _num_patients)) |
|
|
94 |
else: |
|
|
95 |
num_patients_for_current_split = int(self.options.partition[split]) |
|
|
96 |
|
|
|
97 |
if num_patients_for_current_split > (_num_patients - _already_taken): |
|
|
98 |
num_patients_for_current_split = _num_patients - _already_taken |
|
|
99 |
|
|
|
100 |
self.patients_split[split] = _ridx[_already_taken:_already_taken + num_patients_for_current_split] |
|
|
101 |
_already_taken += num_patients_for_current_split |
|
|
102 |
|
|
|
103 |
self._convert_patient_split() # NEW! We have a new format for storing hte patientsSplit which is OS agnostic. |
|
|
104 |
else: |
|
|
105 |
f = open(self.split_name(), 'rb') |
|
|
106 |
self.patients_split = pickle.load(f) |
|
|
107 |
f.close() |
|
|
108 |
self._convert_patient_split() # NEW! We have a new format for storing hte patientsSplit which is OS agnostic. |
|
|
109 |
|
|
|
110 |
# Iterate over all patients and the filtered NII files and extract slices |
|
|
111 |
_images = [] |
|
|
112 |
_labels = [] |
|
|
113 |
_sets = [] |
|
|
114 |
for p, patient in enumerate(self.patients): |
|
|
115 |
if patient["name"] in self.patients_split['TRAIN']: |
|
|
116 |
_set_of_current_patient = BRAINWEB.SET_TYPES.index('TRAIN') |
|
|
117 |
elif patient["name"] in self.patients_split['VAL']: |
|
|
118 |
_set_of_current_patient = BRAINWEB.SET_TYPES.index('VAL') |
|
|
119 |
elif patient["name"] in self.patients_split['TEST']: |
|
|
120 |
_set_of_current_patient = BRAINWEB.SET_TYPES.index('TEST') |
|
|
121 |
|
|
|
122 |
minc, minc_seg, minc_skullmap = self.load_volume_and_groundtruth(patient["filtered_files"][0], patient) |
|
|
123 |
|
|
|
124 |
# Iterate over all slices and collect them |
|
|
125 |
for s in range(self.options.sliceStart, min(self.options.sliceEnd, minc.num_slices_along_axis(self.options.axis))): |
|
|
126 |
if 0 < self.options.numSamples < len(_images): |
|
|
127 |
break |
|
|
128 |
|
|
|
129 |
slice_data = minc.get_slice(s, self.options.axis) |
|
|
130 |
slice_seg = minc_seg.get_slice(s, self.options.axis) |
|
|
131 |
|
|
|
132 |
# Skip the slice if it is entirely black |
|
|
133 |
if numpy.unique(slice_data).size == 1: |
|
|
134 |
continue |
|
|
135 |
|
|
|
136 |
# assert numpy.max(slice_data) <= 1.0, "Slice range is outside [0; 1]!" |
|
|
137 |
|
|
|
138 |
if self.options.sliceResolution is not None: |
|
|
139 |
# If the images are too big in resolution, do downsampling |
|
|
140 |
if slice_data.shape[0] > self.options.sliceResolution[0] or slice_data.shape[1] > self.options.sliceResolution[1]: |
|
|
141 |
slice_data = cv2.resize(slice_data, tuple(self.options.sliceResolution)) |
|
|
142 |
slice_seg = cv2.resize(slice_seg, tuple(self.options.sliceResolution), interpolation=cv2.INTER_NEAREST) |
|
|
143 |
# Otherwise, do zero padding |
|
|
144 |
else: |
|
|
145 |
tmp_slice = numpy.zeros(self.options.sliceResolution) |
|
|
146 |
tmp_slice_seg = numpy.zeros(self.options.sliceResolution) |
|
|
147 |
start_x = (self.options.sliceResolution[1] - slice_data.shape[1]) // 2 |
|
|
148 |
start_y = (self.options.sliceResolution[0] - slice_data.shape[0]) // 2 |
|
|
149 |
end_x = start_x + slice_data.shape[1] |
|
|
150 |
end_y = start_y + slice_data.shape[0] |
|
|
151 |
tmp_slice[start_y:end_y, start_x:end_x] = slice_data |
|
|
152 |
tmp_slice_seg[start_y:end_y, start_x:end_x] = slice_seg |
|
|
153 |
slice_data = tmp_slice |
|
|
154 |
slice_seg = tmp_slice_seg |
|
|
155 |
|
|
|
156 |
for angle in self.options.rotations: |
|
|
157 |
if angle != 0: |
|
|
158 |
slice_data_rotated = rotate(slice_data, angle, reshape=False) |
|
|
159 |
slice_seg_rotated = rotate(slice_seg, angle, reshape=False, mode='nearest') |
|
|
160 |
else: |
|
|
161 |
slice_data_rotated = slice_data |
|
|
162 |
slice_seg_rotated = slice_seg |
|
|
163 |
|
|
|
164 |
# Either collect crops |
|
|
165 |
if self.options.useCrops: |
|
|
166 |
if self.options.cropType == 'random': |
|
|
167 |
rx = numpy.random.randint(0, high=(slice_data_rotated.shape[1] - self.options.cropWidth), |
|
|
168 |
size=self.options.numRandomCropsPerSlice) |
|
|
169 |
ry = numpy.random.randint(0, high=(slice_data_rotated.shape[0] - self.options.cropHeight), |
|
|
170 |
size=self.options.numRandomCropsPerSlice) |
|
|
171 |
for r in range(self.options.numRandomCropsPerSlice): |
|
|
172 |
_images.append(crop(slice_data_rotated, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth)) |
|
|
173 |
_labels.append(crop(slice_data_rotated, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth)) |
|
|
174 |
_sets.append(_set_of_current_patient) |
|
|
175 |
elif self.options.cropType == 'center': |
|
|
176 |
slice_data_cropped = crop_center(slice_data_rotated, self.options.cropWidth, self.options.cropHeight) |
|
|
177 |
slice_seg_cropped = crop_center(slice_seg_rotated, self.options.cropWidth, self.options.cropHeight) |
|
|
178 |
_images.append(slice_data_cropped) |
|
|
179 |
_labels.append(slice_seg_cropped) |
|
|
180 |
_sets.append(_set_of_current_patient) |
|
|
181 |
# Or whole slices |
|
|
182 |
else: |
|
|
183 |
_images.append(slice_data_rotated) |
|
|
184 |
_labels.append(slice_seg_rotated) |
|
|
185 |
_sets.append(_set_of_current_patient) |
|
|
186 |
|
|
|
187 |
self._images = numpy.array(_images).astype(numpy.float32) |
|
|
188 |
self._labels = numpy.array(_labels).astype(numpy.float32) |
|
|
189 |
# assert numpy.max(self._images) <= 1.0, "MINC range is outside [0; 1]!" |
|
|
190 |
if self._images.ndim < 4: |
|
|
191 |
self._images = numpy.expand_dims(self._images, 3) |
|
|
192 |
self._sets = numpy.array(_sets).astype(numpy.int32) |
|
|
193 |
self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} |
|
|
194 |
self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} |
|
|
195 |
|
|
|
196 |
if self.options.cache: |
|
|
197 |
write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name()) |
|
|
198 |
tmp = copy.copy(self) |
|
|
199 |
tmp._images = None |
|
|
200 |
tmp._labels = None |
|
|
201 |
tmp._sets = None |
|
|
202 |
f = open(self.pckl_name(), 'wb') |
|
|
203 |
pickle.dump(tmp, f) |
|
|
204 |
f.close() |
|
|
205 |
|
|
|
206 |
def _get_patients(self): |
|
|
207 |
return BRAINWEB.get_patients(self.options) |
|
|
208 |
|
|
|
209 |
@staticmethod |
|
|
210 |
def get_patients(options): |
|
|
211 |
minc_folders = [options.folderNormal, options.folderMildMS, options.folderModerateMS, options.folderSevereMS] |
|
|
212 |
|
|
|
213 |
# Iterate over all folders and collect patients |
|
|
214 |
patients = [] |
|
|
215 |
for n, minc_folder in enumerate(minc_folders): |
|
|
216 |
if minc_folder == options.folderNormal: |
|
|
217 |
_type = 'NORMAL' |
|
|
218 |
elif minc_folder == options.folderMildMS: |
|
|
219 |
_type = 'MILDMS' |
|
|
220 |
elif minc_folder == options.folderModerateMS: |
|
|
221 |
_type = 'MODERATEMS' |
|
|
222 |
elif minc_folder == options.folderSevereMS: |
|
|
223 |
_type = 'SEVEREMS' |
|
|
224 |
|
|
|
225 |
# Continue with the next patient if the current one is not part of the desired types |
|
|
226 |
if _type not in options.filterType: |
|
|
227 |
continue |
|
|
228 |
|
|
|
229 |
if options.filterProtocol: |
|
|
230 |
_regex = BRAINWEB.PROTOCOL_MAPPINGS[options.filterProtocol] + ".mnc.gz" |
|
|
231 |
else: |
|
|
232 |
_regex = "*.mnc.gz" |
|
|
233 |
_files = glob.glob(os.path.join(options.dir, minc_folder, _regex)) |
|
|
234 |
for f, fname in enumerate(_files): |
|
|
235 |
patient = { |
|
|
236 |
'name': os.path.basename(fname), |
|
|
237 |
'type': _type, |
|
|
238 |
'fullpath': fname |
|
|
239 |
} |
|
|
240 |
patient['filtered_files'] = patient['fullpath'] |
|
|
241 |
|
|
|
242 |
if patient['type'] == 'NORMAL': |
|
|
243 |
patient['groundtruth_filename'] = os.path.join(options.dir, options.folderGT, 'normal.mnc.gz') |
|
|
244 |
elif patient['type'] == 'MILDMS': |
|
|
245 |
patient['groundtruth_filename'] = os.path.join(options.dir, options.folderGT, 'mild_lesions.mnc.gz') |
|
|
246 |
elif patient['type'] == 'MODERATEMS': |
|
|
247 |
patient['groundtruth_filename'] = os.path.join(options.dir, options.folderGT, 'moderate_lesions.mnc.gz') |
|
|
248 |
elif patient['type'] == 'SEVEREMS': |
|
|
249 |
patient['groundtruth_filename'] = os.path.join(options.dir, options.folderGT, 'severe_lesions.mnc.gz') |
|
|
250 |
|
|
|
251 |
patients.append(patient) |
|
|
252 |
|
|
|
253 |
return patients |
|
|
254 |
|
|
|
255 |
def load_volume_and_groundtruth(self, minc_filename, patient): |
|
|
256 |
minc_filename = patient['fullpath'] |
|
|
257 |
try: |
|
|
258 |
minc = MINC(minc_filename) # NII also works with MINC |
|
|
259 |
minc.set_view_mapping(BRAINWEB.VIEW_MAPPING) |
|
|
260 |
except: |
|
|
261 |
print('BRAINWEB: Failed to open file ' + minc_filename) |
|
|
262 |
|
|
|
263 |
# Try to load the segmentation ground-truth |
|
|
264 |
minc_seg_path = patient["groundtruth_filename"] |
|
|
265 |
minc_seg = MINC(minc_seg_path) |
|
|
266 |
skullmap = MINC(minc_seg_path) |
|
|
267 |
skullmap.data = (skullmap.data * 0.0) + 1.0 |
|
|
268 |
skullmap.set_view_mapping(BRAINWEB.VIEW_MAPPING) |
|
|
269 |
minc_seg.set_view_mapping(BRAINWEB.VIEW_MAPPING) |
|
|
270 |
|
|
|
271 |
# If desired, compute the skullmap |
|
|
272 |
if self.options.skullRemoval: |
|
|
273 |
skullmap.data[minc_seg.data == BRAINWEB.LABELS['FAT']] = 0 |
|
|
274 |
skullmap.data[minc_seg.data == BRAINWEB.LABELS['MUSCLE']] = 0 |
|
|
275 |
skullmap.data[minc_seg.data == BRAINWEB.LABELS['SKIN']] = 0 |
|
|
276 |
skullmap.data[minc_seg.data == BRAINWEB.LABELS['SKULL']] = 0 |
|
|
277 |
skullmap.data[minc_seg.data == BRAINWEB.LABELS['CONNECTIVE']] = 0 |
|
|
278 |
|
|
|
279 |
if self.options.backgroundRemoval: |
|
|
280 |
skullmap.data[minc_seg.data == BRAINWEB.LABELS['BACKGROUND']] = 0 |
|
|
281 |
|
|
|
282 |
# Binarize minc_seg |
|
|
283 |
lesion_idx = (minc_seg.data == BRAINWEB.LABELS['LESION']) |
|
|
284 |
nonlesion_idx = (minc_seg.data != BRAINWEB.LABELS['LESION']) |
|
|
285 |
minc_seg.data[lesion_idx] = 1 |
|
|
286 |
minc_seg.data[nonlesion_idx] = 0 |
|
|
287 |
|
|
|
288 |
if self.options.skullRemoval or self.options.backgroundRemoval: |
|
|
289 |
minc.apply_skullmap(skullmap) |
|
|
290 |
|
|
|
291 |
# In-place normalize the loaded volume |
|
|
292 |
minc.normalize(method=self.options.normalizationMethod, lowerpercentile=0.0, upperpercentile=99.8) |
|
|
293 |
# 99.8 percentile described in LG Ny´ul, Jayaram K Udupa, and Xuan Zhang. |
|
|
294 |
# New variants of a method of MRI scale standardization. |
|
|
295 |
# IEEE transactions on medical imaging, 19(2):143–150, 2000. |
|
|
296 |
# assert numpy.max(minc.getData()) <= 1.0, "MINC range is outside [0; 1]!" |
|
|
297 |
|
|
|
298 |
return minc, minc_seg, skullmap |
|
|
299 |
|
|
|
300 |
# Returns the indices of patients which belong to either TRAIN, VAL or TEST. Your choice |
|
|
301 |
def get_patient_idx(self, split='TRAIN'): |
|
|
302 |
idx = [] |
|
|
303 |
for pidx, patient in enumerate(self.patients): |
|
|
304 |
if patient["name"] in self.patients_split[split]: |
|
|
305 |
idx += [pidx] |
|
|
306 |
return idx |
|
|
307 |
|
|
|
308 |
def get_patient_split(self): |
|
|
309 |
return self.patients_split |
|
|
310 |
|
|
|
311 |
@property |
|
|
312 |
def images(self): |
|
|
313 |
return self._images |
|
|
314 |
|
|
|
315 |
def get_images(self, set=None): |
|
|
316 |
_setIdx = BRAINWEB.SET_TYPES.index(set) |
|
|
317 |
images_in_set = numpy.where(self._sets == _setIdx)[0] |
|
|
318 |
return self._images[images_in_set] |
|
|
319 |
|
|
|
320 |
def get_image(self, i): |
|
|
321 |
return self._images[i, :, :, :] |
|
|
322 |
|
|
|
323 |
def get_label(self, i): |
|
|
324 |
return self._labels[i, :, :, :] |
|
|
325 |
|
|
|
326 |
@property |
|
|
327 |
def labels(self): |
|
|
328 |
return self._labels |
|
|
329 |
|
|
|
330 |
@property |
|
|
331 |
def sets(self): |
|
|
332 |
return self._sets |
|
|
333 |
|
|
|
334 |
@property |
|
|
335 |
def meta(self): |
|
|
336 |
return self._meta |
|
|
337 |
|
|
|
338 |
@property |
|
|
339 |
def num_examples(self): |
|
|
340 |
return self._images.shape[0] |
|
|
341 |
|
|
|
342 |
@property |
|
|
343 |
def width(self): |
|
|
344 |
return self._images.shape[2] |
|
|
345 |
|
|
|
346 |
@property |
|
|
347 |
def height(self): |
|
|
348 |
return self._images.shape[1] |
|
|
349 |
|
|
|
350 |
@property |
|
|
351 |
def num_channels(self): |
|
|
352 |
return self._images.shape[3] |
|
|
353 |
|
|
|
354 |
@property |
|
|
355 |
def epochs_completed(self): |
|
|
356 |
return self._epochs_completed |
|
|
357 |
|
|
|
358 |
def name(self): |
|
|
359 |
_name = "BRAINWEB" |
|
|
360 |
if self.options.description: |
|
|
361 |
_name += "_{}".format(self.options.description) |
|
|
362 |
if self.options.numSamples > 0: |
|
|
363 |
_name += '_n{}'.format(self.options.numSamples) |
|
|
364 |
_name += "_p{}-{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL'], self.options.partition['TEST']) |
|
|
365 |
if self.options.useCrops: |
|
|
366 |
_name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight) |
|
|
367 |
if self.options.cropType == "random": |
|
|
368 |
_name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice) |
|
|
369 |
if self.options.sliceResolution is not None: |
|
|
370 |
_name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1]) |
|
|
371 |
if self.options.skullRemoval: |
|
|
372 |
_name += "_noSkull" |
|
|
373 |
if self.options.backgroundRemoval: |
|
|
374 |
_name += "_noBackground" |
|
|
375 |
return _name |
|
|
376 |
|
|
|
377 |
def pckl_name(self): |
|
|
378 |
return os.path.join(self.dir(), self.name() + ".pckl") |
|
|
379 |
|
|
|
380 |
def tfrecord_name(self): |
|
|
381 |
return os.path.join(self.dir(), self.name() + ".tfrecord") |
|
|
382 |
|
|
|
383 |
def split_name(self): |
|
|
384 |
return os.path.join(self.dir(), |
|
|
385 |
'split-{}-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL'], self.options.partition['TEST'])) |
|
|
386 |
|
|
|
387 |
def dir(self): |
|
|
388 |
return self.options.dir |
|
|
389 |
|
|
|
390 |
def export_slices(self, dir): |
|
|
391 |
for i in range(self.num_examples): |
|
|
392 |
imwrite(os.path.join(dir, '{}.png'.format(i)), np.squeeze(self.get_image(i) * 255).astype('uint8')) |
|
|
393 |
|
|
|
394 |
def visualize(self, pause=1, set='TRAIN'): |
|
|
395 |
f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2) |
|
|
396 |
images_tmp, labels_tmp, _ = self.next_batch(10, set=set) |
|
|
397 |
for i in range(images_tmp.shape[0]): |
|
|
398 |
img = numpy.squeeze(images_tmp[i]) |
|
|
399 |
lbl = numpy.squeeze(labels_tmp[i]) |
|
|
400 |
ax1.imshow(img) |
|
|
401 |
ax1.set_title('Patch') |
|
|
402 |
ax2.imshow(lbl) |
|
|
403 |
ax2.set_title('Groundtruth') |
|
|
404 |
matplotlib.pyplot.pause(pause) |
|
|
405 |
|
|
|
406 |
def num_batches(self, batchsize, set='TRAIN'): |
|
|
407 |
_setIdx = BRAINWEB.SET_TYPES.index(set) |
|
|
408 |
images_in_set = numpy.where(self._sets == _setIdx)[0] |
|
|
409 |
return len(images_in_set) // batchsize |
|
|
410 |
|
|
|
411 |
def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=False): |
|
|
412 |
"""Return the next `batch_size` examples from this data set.""" |
|
|
413 |
_setIdx = BRAINWEB.SET_TYPES.index(set) |
|
|
414 |
images_in_set = numpy.where(self._sets == _setIdx)[0] |
|
|
415 |
samples_in_set = len(images_in_set) |
|
|
416 |
|
|
|
417 |
start = self._index_in_epoch[set] |
|
|
418 |
# Shuffle for the first epoch |
|
|
419 |
if self._epochs_completed == 0 and start == 0 and shuffle: |
|
|
420 |
perm0 = numpy.arange(samples_in_set) |
|
|
421 |
numpy.random.shuffle(perm0) |
|
|
422 |
self._images[images_in_set] = self.images[images_in_set[perm0]] |
|
|
423 |
self._labels[images_in_set] = self.labels[images_in_set[perm0]] |
|
|
424 |
self._sets[images_in_set] = self.sets[images_in_set[perm0]] |
|
|
425 |
|
|
|
426 |
# Go to the next epoch |
|
|
427 |
if start + batch_size > samples_in_set: |
|
|
428 |
# Finished epoch |
|
|
429 |
self._epochs_completed[set] += 1 |
|
|
430 |
|
|
|
431 |
# Get the rest examples in this epoch |
|
|
432 |
rest_num_examples = samples_in_set - start |
|
|
433 |
images_rest_part = self._images[images_in_set[start:samples_in_set]] |
|
|
434 |
labels_rest_part = self._labels[images_in_set[start:samples_in_set]] |
|
|
435 |
|
|
|
436 |
# Shuffle the data |
|
|
437 |
if shuffle: |
|
|
438 |
perm = numpy.arange(samples_in_set) |
|
|
439 |
numpy.random.shuffle(perm) |
|
|
440 |
self._images[images_in_set] = self.images[images_in_set[perm]] |
|
|
441 |
self._labels[images_in_set] = self.labels[images_in_set[perm]] |
|
|
442 |
self._sets[images_in_set] = self.sets[images_in_set[perm]] |
|
|
443 |
|
|
|
444 |
# Start next epoch |
|
|
445 |
start = 0 |
|
|
446 |
self._index_in_epoch[set] = batch_size - rest_num_examples |
|
|
447 |
end = self._index_in_epoch[set] |
|
|
448 |
images_new_part = self._images[images_in_set[start:end]] |
|
|
449 |
labels_new_part = self._labels[images_in_set[start:end]] |
|
|
450 |
|
|
|
451 |
images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0) |
|
|
452 |
labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0) |
|
|
453 |
else: |
|
|
454 |
self._index_in_epoch[set] += batch_size |
|
|
455 |
end = self._index_in_epoch[set] |
|
|
456 |
images_tmp = self._images[images_in_set[start:end]] |
|
|
457 |
labels_tmp = self._labels[images_in_set[start:end]] |
|
|
458 |
|
|
|
459 |
if self.options.addInstanceNoise: |
|
|
460 |
noise = numpy.random.normal(0, 0.01, images_tmp.shape) |
|
|
461 |
images_tmp += noise |
|
|
462 |
|
|
|
463 |
# Check the batch |
|
|
464 |
assert images_tmp.size, "The batch is empty!" |
|
|
465 |
assert labels_tmp.size, "The labels of the current batch are empty!" |
|
|
466 |
|
|
|
467 |
if return_brainmask: |
|
|
468 |
brainmasks = np.copy(labels_tmp) |
|
|
469 |
brainmasks[brainmasks == BRAINWEB.LABELS['FAT']] = 0 |
|
|
470 |
brainmasks[brainmasks == BRAINWEB.LABELS['MUSCLE']] = 0 |
|
|
471 |
brainmasks[brainmasks == BRAINWEB.LABELS['SKIN']] = 0 |
|
|
472 |
brainmasks[brainmasks == BRAINWEB.LABELS['SKULL']] = 0 |
|
|
473 |
brainmasks[brainmasks == BRAINWEB.LABELS['CONNECTIVE']] = 0 |
|
|
474 |
brainmasks[brainmasks == BRAINWEB.LABELS['BACKGROUND']] = 0 |
|
|
475 |
brainmasks[brainmasks > 0] = 1 |
|
|
476 |
return images_tmp, labels_tmp, brainmasks |
|
|
477 |
|
|
|
478 |
return images_tmp, labels_tmp, None |
|
|
479 |
|
|
|
480 |
def _convert_patient_split(self): |
|
|
481 |
for split in self.patients_split.keys(): |
|
|
482 |
_list_of_patient_names = [] |
|
|
483 |
for pidx in self.patients_split[split]: |
|
|
484 |
if not isinstance(pidx, str): |
|
|
485 |
_list_of_patient_names += [self.patients[pidx]['name']] |
|
|
486 |
else: |
|
|
487 |
_list_of_patient_names = self.patients_split[split] |
|
|
488 |
break |
|
|
489 |
self.patients_split[split] = _list_of_patient_names |
|
|
490 |
|
|
|
491 |
f = open(self.split_name(), 'wb') |
|
|
492 |
pickle.dump(self.patients_split, f) |
|
|
493 |
f.close() |