Diff of /acdc_data_preparation.py [000000] .. [98e649]

import numpy as np
import os, sys, shutil, time, re
import h5py
import skimage.morphology as morph
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
import pickle
# For ROI extraction
import skimage.transform
from scipy.fftpack import fftn, ifftn
from skimage.feature import peak_local_max, canny
from skimage.transform import hough_circle
# Nifti processing
import nibabel as nib
from collections import OrderedDict
# print sys.path
# sys.path.append("..")
import errno
np.random.seed(42)

# Helper functions
## Heart Metrics
def heart_metrics(seg_3Dmap, voxel_size, classes=[3, 1, 2]):
    """
    Compute the volume of each class
    """
    # Loop over each class of the input image
    volumes = []
    for c in classes:
        # Copy the gt image so as not to alter the input
        seg_3Dmap_copy = np.copy(seg_3Dmap)
        seg_3Dmap_copy[seg_3Dmap_copy != c] = 0

        # Clip the values to compute the volume
        seg_3Dmap_copy = np.clip(seg_3Dmap_copy, 0, 1)

        # Compute volume in ml
        volume = seg_3Dmap_copy.sum() * np.prod(voxel_size) / 1000.
        volumes += [volume]
    return volumes

def ejection_fraction(ed_vol, es_vol):
    """
    Calculate ejection fraction
    """
    stroke_vol = ed_vol - es_vol
    return (float(stroke_vol) / float(ed_vol)) * 100

def myocardialmass(myocardvol):
    """
    Specific gravity of heart muscle (1.05 g/ml)
    """
    return myocardvol * 1.05
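
# Illustrative sketch (not a call made at module level in this script): the three
# helpers above chain together to produce the clinical indices stored later under
# the 'HP' key. `ed_gt`, `es_gt` and `hdr` are assumed to come from load_nii().
#
#   ed_lv, ed_rv, ed_myo = heart_metrics(ed_gt, hdr.get_zooms())   # volumes in ml
#   es_lv, es_rv, es_myo = heart_metrics(es_gt, hdr.get_zooms())
#   ef_lv = ejection_fraction(ed_lv, es_lv)    # 100 * stroke volume / EDV
#   lv_mass = myocardialmass(ed_myo)           # myocardial volume * 1.05 g/ml
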
def imshow(*args, **kwargs):
    """ Handy function to show multiple plots in one row, possibly with different cmaps and titles
    Usage:
    imshow(img1, title="myPlot")
    imshow(img1, img2, title=['title1','title2'])
    imshow(img1, img2, cmap='hot')
    imshow(img1, img2, cmap=['gray','Blues']) """
    cmap = kwargs.get('cmap', 'gray')
    title = kwargs.get('title', '')
    if len(args) == 0:
        raise ValueError("No images given to imshow")
    elif len(args) == 1:
        plt.title(title)
        plt.imshow(args[0], cmap=cmap, interpolation='none')
    else:
        n = len(args)
        if type(cmap) == str:
            cmap = [cmap] * n
        if type(title) == str:
            title = [title] * n
        plt.figure(figsize=(n * 5, 10))
        for i in range(n):
            plt.subplot(1, n, i + 1)
            plt.title(title[i])
            plt.imshow(args[i], cmap[i])
    plt.show()

def plot_roi(data4D, roi_center, roi_radii):
    """
    Animate the full heart volume with the ROI region highlighted
    """
    x_roi_center, y_roi_center = roi_center[0], roi_center[1]
    x_roi_radius, y_roi_radius = roi_radii[0], roi_radii[1]
    print('nslices', data4D.shape[2])

    zslices = data4D.shape[2]
    tframes = data4D.shape[3]

    slice_cnt = 0
    for slice in [data4D[:, :, z, :] for z in range(zslices)]:
        outdata = np.swapaxes(np.swapaxes(slice[:, :, :], 0, 2), 1, 2)
        roi_mask = np.zeros_like(outdata[0])
        roi_mask[x_roi_center - x_roi_radius:x_roi_center + x_roi_radius,
                 y_roi_center - y_roi_radius:y_roi_center + y_roi_radius] = 1

        # Darken the ROI region so it stands out in the animation
        outdata[:, roi_mask > 0.5] = 0.8 * outdata[:, roi_mask > 0.5]

        fig = plt.figure(1)
        # set_window_title moved to the figure manager in recent matplotlib versions
        fig.canvas.manager.set_window_title('slice_No' + str(slice_cnt))
        slice_cnt += 1

        def init_out():
            im.set_data(outdata[0])

        def animate_out(i):
            im.set_data(outdata[i])
            return im

        im = fig.gca().imshow(outdata[0], cmap='gray')
        anim = animation.FuncAnimation(fig, animate_out, init_func=init_out, frames=tframes, interval=50)
        anim.save('Cine_MRI_SAX_%d.mp4' % slice_cnt, fps=50, extra_args=['-vcodec', 'libx264'])
        plt.show()

def plot_4D(data4D):
    """
    Animate the full heart volume
    """
    print('nslices', data4D.shape[2])
    zslices = data4D.shape[2]
    tframes = data4D.shape[3]

    slice_cnt = 0
    for slice in [data4D[:, :, z, :] for z in range(zslices)]:
        outdata = np.swapaxes(np.swapaxes(slice[:, :, :], 0, 2), 1, 2)
        fig = plt.figure(1)
        # set_window_title moved to the figure manager in recent matplotlib versions
        fig.canvas.manager.set_window_title('slice_No' + str(slice_cnt))
        slice_cnt += 1

        def init_out():
            im.set_data(outdata[0])

        def animate_out(i):
            im.set_data(outdata[i])
            return im

        im = fig.gca().imshow(outdata[0], cmap='gray')
        anim = animation.FuncAnimation(fig, animate_out, init_func=init_out, frames=tframes, interval=50)
        plt.show()

def multilabel_split(image_tensor):
    """
    image_tensor : Batch * H * W
    Split multilabel images and return stack of images
    Returns: Tensor of shape: Batch * H * W * n_class (4D tensor)
    # TODO: Be careful when using this code: labels need to be
    defined explicitly beforehand, as this code does not handle
    missing labels.
    So far, this function is okay as it considers the full volume for
    finding out the unique labels
    """
    labels = np.unique(image_tensor)
    batch_size = image_tensor.shape[0]
    out_shape = image_tensor.shape + (len(labels),)
    image_tensor_4D = np.zeros(out_shape, dtype='uint8')
    for i in range(batch_size):
        cnt = 0
        shape = image_tensor.shape[1:3] + (len(labels),)
        temp = np.ones(shape, dtype='uint8')
        for label in labels:
            temp[..., cnt] = np.where(image_tensor[i] == label, temp[..., cnt], 0)
            cnt += 1
        image_tensor_4D[i] = temp
    return image_tensor_4D
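
# Illustrative sketch (toy values): a batch containing the labels {0, 1, 3} is
# split into one binary channel per label that np.unique() finds in the tensor.
#
#   toy_batch = np.array([[[0, 1], [3, 1]], [[0, 0], [3, 3]]])   # shape (2, 2, 2)
#   one_hot = multilabel_split(toy_batch)                        # shape (2, 2, 2, 3), uint8
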
def save_data(data, filename, out_path):
    out_filename = os.path.join(out_path, filename)
    with open(out_filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print('saved to %s' % out_filename)

def load_pkl(path):
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj
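
# Illustrative round trip (hypothetical names): the per-patient dictionaries built
# below are pickled with save_data() and read back with load_pkl().
#
#   save_data(train_dict, 'train_set.pkl', out_path_train)
#   train_dict = load_pkl(os.path.join(out_path_train, 'train_set.pkl'))
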

### Stratified Sampling of data

# Refer:
# http://www.echopedia.org/wiki/Left_Ventricular_Dimensions
# https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
# https://en.wikipedia.org/wiki/Body_surface_area
# 30 normal subjects - NOR
NORMAL = 'NOR'
# 30 patients with previous myocardial infarction
# (ejection fraction of the left ventricle lower than 40% and several myocardial segments with abnormal contraction) - MINF
MINF = 'MINF'
# 30 patients with dilated cardiomyopathy
# (diastolic left ventricular volume >100 mL/m2 and an ejection fraction of the left ventricle lower than 40%) - DCM
DCM = 'DCM'
# 30 patients with hypertrophic cardiomyopathy
# (left ventricular cardiac mass higher than 110 g/m2,
# several myocardial segments with a thickness higher than 15 mm in diastole and a normal ejection fraction) - HCM
HCM = 'HCM'
# 30 patients with abnormal right ventricle (volume of the right ventricular
# cavity higher than 110 mL/m2 or ejection fraction of the right ventricle lower than 40%) - RV
RV = 'RV'

def copy(src, dest):
    """
    Copy a directory tree; fall back to copying a single file
    """
    try:
        shutil.copytree(src, dest, ignore=shutil.ignore_patterns())
    except OSError as e:
        # If the error was caused because the source wasn't a directory
        if e.errno == errno.ENOTDIR:
            shutil.copy(src, dest)
        else:
            print('Directory not copied. Error: %s' % e)

def read_patient_cfg(path):
    """
    Reads patient data from the Info.cfg file and returns a dictionary
    """
    patient_info = {}
    with open(os.path.join(path, 'Info.cfg')) as f_in:
        for line in f_in:
            l = line.rstrip().split(": ")
            patient_info[l[0]] = l[1]
    return patient_info
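
# Example (illustrative): an ACDC Info.cfg typically contains lines such as
#
#   ED: 1
#   ES: 12
#   Group: DCM
#
# for which read_patient_cfg(path) returns {'ED': '1', 'ES': '12', 'Group': 'DCM', ...}.
# Values stay strings; the downstream code relies on the ED, ES and Group keys.
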

def group_patient_cases(src_path, out_path, force=False):
    """ Group the patient data according to cardiac pathology """

    cases = sorted(next(os.walk(src_path))[1])
    dest_path = os.path.join(out_path, 'Patient_Groups')
    if force and os.path.exists(dest_path):
        shutil.rmtree(dest_path)
    if os.path.exists(dest_path):
        return dest_path

    os.makedirs(dest_path)
    os.mkdir(os.path.join(dest_path, NORMAL))
    os.mkdir(os.path.join(dest_path, MINF))
    os.mkdir(os.path.join(dest_path, DCM))
    os.mkdir(os.path.join(dest_path, HCM))
    os.mkdir(os.path.join(dest_path, RV))

    for case in cases:
        full_path = os.path.join(src_path, case)
        copy(full_path, os.path.join(dest_path,
             read_patient_cfg(full_path)['Group'], case))

def generate_train_validate_test_set(src_path, dest_path):
    """
    Split the data into 70:15:15 train-validation-test sets
    arg: path: input data path
    """
    SPLIT_TRAIN = 0.7
    SPLIT_VALID = 0.15

    dest_path = os.path.join(dest_path, 'dataset')
    if os.path.exists(dest_path):
        shutil.rmtree(dest_path)
    os.makedirs(os.path.join(dest_path, 'train_set'))
    os.makedirs(os.path.join(dest_path, 'validation_set'))
    os.makedirs(os.path.join(dest_path, 'test_set'))
    # print (src_path)
    groups = next(os.walk(src_path))[1]
    for group in groups:
        group_path = next(os.walk(os.path.join(src_path, group)))[0]
        patient_folders = next(os.walk(group_path))[1]
        np.random.shuffle(patient_folders)
        train_ = patient_folders[0:int(SPLIT_TRAIN * len(patient_folders))]
        valid_ = patient_folders[int(SPLIT_TRAIN * len(patient_folders)):
                                 int((SPLIT_TRAIN + SPLIT_VALID) * len(patient_folders))]
        test_ = patient_folders[int((SPLIT_TRAIN + SPLIT_VALID) * len(patient_folders)):]
        for patient in train_:
            folder_path = os.path.join(group_path, patient)
            copy(folder_path, os.path.join(dest_path, 'train_set', patient))

        for patient in valid_:
            folder_path = os.path.join(group_path, patient)
            copy(folder_path, os.path.join(dest_path, 'validation_set', patient))

        for patient in test_:
            folder_path = os.path.join(group_path, patient)
            copy(folder_path, os.path.join(dest_path, 'test_set', patient))
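
# Illustrative arithmetic: each ACDC pathology group holds 30 patients, so the
# integer cut points above give 21 train, 4 validation and 5 test patients per
# group (int(0.7 * 30) = 21, int(0.85 * 30) = 25), i.e. 105/20/25 patients overall.
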

#   Fourier-Hough Transform Based ROI Extraction
def extract_roi_fft(data4D, pixel_spacing, minradius_mm=15, maxradius_mm=45, kernel_width=5,
                    center_margin=8, num_peaks=10, num_circles=20, radstep=2):
    """
    Returns center and radii of ROI region in (i,j) format
    """
    # Radii of the smallest and largest circles in mm, estimated from the train set,
    # are converted to pixel counts using the in-plane pixel spacing
    pixel_spacing_X, pixel_spacing_Y, _, _ = pixel_spacing
    minradius = int(minradius_mm / pixel_spacing_X)
    maxradius = int(maxradius_mm / pixel_spacing_Y)

    ximagesize = data4D.shape[0]
    yimagesize = data4D.shape[1]
    zslices = data4D.shape[2]
    tframes = data4D.shape[3]
    xsurface = np.tile(range(ximagesize), (yimagesize, 1)).T
    ysurface = np.tile(range(yimagesize), (ximagesize, 1))
    lsurface = np.zeros((ximagesize, yimagesize))

    allcenters = []
    allaccums = []
    allradii = []

    for slice in range(zslices):
        # First temporal harmonic of the cine sequence highlights moving structures
        ff1 = fftn([data4D[:, :, slice, t] for t in range(tframes)])
        fh = np.absolute(ifftn(ff1[1, :, :]))
        fh[fh < 0.1 * np.max(fh)] = 0.0
        image = 1. * fh / np.max(fh)
        # find hough circles and detect two radii
        edges = canny(image, sigma=3)
        hough_radii = np.arange(minradius, maxradius, radstep)
        # print hough_radii
        hough_res = hough_circle(edges, hough_radii)
        if hough_res.any():
            centers = []
            accums = []
            radii = []

            for radius, h in zip(hough_radii, hough_res):
                # For each radius, extract num_peaks circles
                peaks = peak_local_max(h, num_peaks=num_peaks)
                centers.extend(peaks)
                accums.extend(h[peaks[:, 0], peaks[:, 1]])
                radii.extend([radius] * num_peaks)

            # Keep the most prominent num_circles circles
            sorted_circles_idxs = np.argsort(accums)[::-1][:num_circles]

            for idx in sorted_circles_idxs:
                center_x, center_y = centers[idx]
                allcenters.append(centers[idx])
                allradii.append(radii[idx])
                allaccums.append(accums[idx])
                brightness = accums[idx]
                # Accumulate a Gaussian-weighted likelihood surface around each circle center
                lsurface = lsurface + brightness * np.exp(
                    -((xsurface - center_x) ** 2 + (ysurface - center_y) ** 2) / kernel_width ** 2)

    lsurface = lsurface / lsurface.max()
    # select most likely ROI center
    roi_center = np.unravel_index(lsurface.argmax(), lsurface.shape)

    # determine ROI radius
    roi_x_radius = 0
    roi_y_radius = 0
    for idx in range(len(allcenters)):
        xshift = np.abs(allcenters[idx][0] - roi_center[0])
        yshift = np.abs(allcenters[idx][1] - roi_center[1])
        if (xshift <= center_margin) & (yshift <= center_margin):
            roi_x_radius = np.max((roi_x_radius, allradii[idx] + xshift))
            roi_y_radius = np.max((roi_y_radius, allradii[idx] + yshift))

    if roi_x_radius > 0 and roi_y_radius > 0:
        roi_radii = roi_x_radius, roi_y_radius
    else:
        roi_radii = None

    return roi_center, roi_radii

#   Stddev-Hough Transform Based ROI Extraction
def extract_roi_stddev(data4D, pixel_spacing, minradius_mm=15, maxradius_mm=45, kernel_width=5,
                       center_margin=8, num_peaks=10, num_circles=20, radstep=2):
    """
    Returns center and radii of ROI region in (i,j) format
    """
    # Radii of the smallest and largest circles in mm, estimated from the train set,
    # are converted to pixel counts using the in-plane pixel spacing
    pixel_spacing_X, pixel_spacing_Y, _, _ = pixel_spacing
    minradius = int(minradius_mm / pixel_spacing_X)
    maxradius = int(maxradius_mm / pixel_spacing_Y)

    ximagesize = data4D.shape[0]
    yimagesize = data4D.shape[1]
    zslices = data4D.shape[2]
    tframes = data4D.shape[3]
    xsurface = np.tile(range(ximagesize), (yimagesize, 1)).T
    ysurface = np.tile(range(yimagesize), (ximagesize, 1))
    lsurface = np.zeros((ximagesize, yimagesize))

    allcenters = []
    allaccums = []
    allradii = []

    for slice in range(zslices):
        # Temporal standard deviation highlights moving structures
        ff1 = np.array([data4D[:, :, slice, t] for t in range(tframes)])
        fh = np.std(ff1, axis=0)
        fh[fh < 0.1 * np.max(fh)] = 0.0
        image = 1. * fh / np.max(fh)
        # find hough circles and detect two radii
        edges = canny(image, sigma=3)
        hough_radii = np.arange(minradius, maxradius, radstep)
        # print hough_radii
        hough_res = hough_circle(edges, hough_radii)
        if hough_res.any():
            centers = []
            accums = []
            radii = []
            for radius, h in zip(hough_radii, hough_res):
                # For each radius, extract num_peaks circles
                peaks = peak_local_max(h, num_peaks=num_peaks)
                centers.extend(peaks)
                accums.extend(h[peaks[:, 0], peaks[:, 1]])
                radii.extend([radius] * num_peaks)

            # Keep the most prominent num_circles circles
            sorted_circles_idxs = np.argsort(accums)[::-1][:num_circles]

            for idx in sorted_circles_idxs:
                center_x, center_y = centers[idx]
                allcenters.append(centers[idx])
                allradii.append(radii[idx])
                allaccums.append(accums[idx])
                brightness = accums[idx]
                # Accumulate a Gaussian-weighted likelihood surface around each circle center
                lsurface = lsurface + brightness * np.exp(
                    -((xsurface - center_x) ** 2 + (ysurface - center_y) ** 2) / kernel_width ** 2)

    lsurface = lsurface / lsurface.max()
    # select most likely ROI center
    roi_center = np.unravel_index(lsurface.argmax(), lsurface.shape)

    # determine ROI radius
    roi_x_radius = 0
    roi_y_radius = 0
    for idx in range(len(allcenters)):
        xshift = np.abs(allcenters[idx][0] - roi_center[0])
        yshift = np.abs(allcenters[idx][1] - roi_center[1])
        if (xshift <= center_margin) & (yshift <= center_margin):
            roi_x_radius = np.max((roi_x_radius, allradii[idx] + xshift))
            roi_y_radius = np.max((roi_y_radius, allradii[idx] + yshift))

    if roi_x_radius > 0 and roi_y_radius > 0:
        roi_radii = roi_x_radius, roi_y_radius
    else:
        roi_radii = None

    return roi_center, roi_radii
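
# Illustrative usage (hypothetical path, following the ACDC layout used in __main__):
# both ROI extractors share the same signature; extract_roi_fft works on the first
# temporal harmonic of the cine sequence, extract_roi_stddev on the temporal standard
# deviation. Each returns ((row, col), (row_radius, col_radius)), or (center, None)
# when no accepted circle lies within center_margin of the most likely center.
#
#   nimg = nib.load('../../ACDC_DataSet/training/patient001/patient001_4d.nii.gz')
#   img_4d = np.asanyarray(nimg.dataobj)
#   roi_center, roi_radii = extract_roi_stddev(img_4d, nimg.header.get_zooms())
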

class Dataset(object):
    def __init__(self, directory, subdir):
        # type: (object, object) -> object
        self.patient_data = {}
        self.directory = directory
        self.name = subdir

    def _filename(self, file):
        return os.path.join(self.directory, self.name, file)

    def load_nii(self, img_path):
        """
        Function to load a 'nii' or 'nii.gz' file. The function returns
        everything needed to save another 'nii' or 'nii.gz'
        in the same dimensional space, i.e. the affine matrix and the header

        Parameters
        ----------

        img_path: string
        String with the path of the 'nii' or 'nii.gz' image file name.

        Returns
        -------
        Three elements: the first is a numpy array of the image values,
        the second is the affine transformation of the image, and the
        last one is the header of the image.
        """
        nimg = nib.load(self._filename(img_path))
        # nimg.get_data() was removed in recent nibabel releases;
        # np.asanyarray(nimg.dataobj) returns the image array without casting
        return np.asanyarray(nimg.dataobj), nimg.affine, nimg.header

    def read_patient_info_data(self):
        """
        Reads patient metadata from the Info.cfg file in the patient folder
        """
        print(self._filename('Info.cfg'))
        with open(self._filename('Info.cfg')) as f_in:
            for line in f_in:
                l = line.rstrip().split(": ")
                self.patient_data[l[0]] = l[1]

    def read_patient_data(self, mode='train', roi_detect=True):
        """
        Reads patient data from the Info.cfg file into a dictionary and
        extracts the End Diastole and End Systole images from the patient folder
        """
        self.read_patient_info_data()
        # Read patient Number
        m = re.match(r"patient(\d{3})", self.name)
        patient_No = int(m.group(1))
        # Read Diastole frame Number
        ED_frame_No = int(self.patient_data['ED'])
        ed_img = "patient%03d_frame%02d.nii.gz" % (patient_No, ED_frame_No)
        ed, affine, hdr = self.load_nii(ed_img)
        # Read Systole frame Number
        ES_frame_No = int(self.patient_data['ES'])
        es_img = "patient%03d_frame%02d.nii.gz" % (patient_No, ES_frame_No)
        es, _, _ = self.load_nii(es_img)
        # Save Images:
        self.patient_data['ED_VOL'] = ed
        self.patient_data['ES_VOL'] = es

        # Header Info for saving
        header_info = {'affine': affine, 'hdr': hdr}
        self.patient_data['header'] = header_info
        if mode == 'reader':
            # Read the full 4D image
            img_4d_name = "patient%03d_4d.nii.gz" % patient_No
            # Load data
            img_4D, _, hdr = self.load_nii(img_4d_name)
            self.patient_data['4D'] = img_4D

            ed_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" % (patient_No, ED_frame_No))
            es_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" % (patient_No, ES_frame_No))
            ed_lv, ed_rv, ed_myo = heart_metrics(ed_gt, hdr.get_zooms())
            es_lv, es_rv, es_myo = heart_metrics(es_gt, hdr.get_zooms())
            ef_lv = ejection_fraction(ed_lv, es_lv)
            ef_rv = ejection_fraction(ed_rv, es_rv)
            heart_param = {'EDV_LV': ed_lv, 'EDV_RV': ed_rv, 'ESV_LV': es_lv, 'ESV_RV': es_rv,
                           'ED_MYO': ed_myo, 'ES_MYO': es_myo, 'EF_LV': ef_lv, 'EF_RV': ef_rv}
            self.patient_data['HP'] = heart_param
            self.patient_data['ED_GT'] = ed_gt
            self.patient_data['ES_GT'] = es_gt
            return

        if mode == 'train':
            ed_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" % (patient_No, ED_frame_No))
            es_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" % (patient_No, ES_frame_No))
            ed_lv, ed_rv, ed_myo = heart_metrics(ed_gt, hdr.get_zooms())
            es_lv, es_rv, es_myo = heart_metrics(es_gt, hdr.get_zooms())
            ef_lv = ejection_fraction(ed_lv, es_lv)
            ef_rv = ejection_fraction(ed_rv, es_rv)
            heart_param = {'EDV_LV': ed_lv, 'EDV_RV': ed_rv, 'ESV_LV': es_lv, 'ESV_RV': es_rv,
                           'ED_MYO': ed_myo, 'ES_MYO': es_myo, 'EF_LV': ef_lv, 'EF_RV': ef_rv}
            self.patient_data['HP'] = heart_param
            self.patient_data['ED_GT'] = ed_gt
            self.patient_data['ES_GT'] = es_gt

        if mode == 'tester':
            # Read the full 4D image
            img_4d_name = "patient%03d_4d.nii.gz" % patient_No
            # Load data
            img_4D, _, hdr = self.load_nii(img_4d_name)
            self.patient_data['4D'] = img_4D

        if roi_detect:
            # Detect the ROI from the full 4D image
            img_4d_name = "patient%03d_4d.nii.gz" % patient_No
            # Load data
            img_4D, _, hdr = self.load_nii(img_4d_name)
            c, r = extract_roi_stddev(img_4D, hdr.get_zooms())
            self.patient_data['roi_center'], self.patient_data['roi_radii'] = c, r
            self.patient_data['4D'] = img_4D
#             print c, r
#             plot_roi(img_4D, c,r)

def convert_nii_np(data_path, mode, roi_detect):
    """
    Prepare a dictionary of the dataset and return it for pickling
    """
    patient_fulldata = OrderedDict()
    print(data_path)
    patient_folders = next(os.walk(data_path))[1]
    for patient in tqdm(sorted(patient_folders)):
        # print (patient)
        dset = Dataset(data_path, patient)
        dset.read_patient_data(mode=mode, roi_detect=roi_detect)
        patient_fulldata[dset.name] = dset.patient_data
    return patient_fulldata
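
# Illustrative sketch of the structure returned by convert_nii_np (keys as filled in
# by read_patient_data with mode='train' and roi_detect=True; other Info.cfg fields
# are carried along as strings):
#
#   {'patient001': {'ED': '1', 'ES': '12', 'Group': 'DCM', ...,
#                   'ED_VOL': ndarray, 'ES_VOL': ndarray,
#                   'ED_GT': ndarray, 'ES_GT': ndarray,
#                   'HP': {'EDV_LV': ..., 'EF_LV': ..., ...},
#                   'header': {'affine': ndarray, 'hdr': Nifti1Header},
#                   'roi_center': (row, col), 'roi_radii': (row_radius, col_radius) or None,
#                   '4D': ndarray},
#    'patient002': {...}, ...}
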

if __name__ == '__main__':
    start_time = time.time()
    # Path to ACDC training database
    complete_data_path = '../../ACDC_DataSet/training'
    dest_path = '../../processed_acdc_dataset'
    group_path = '../../processed_acdc_dataset/Patient_Groups'

    # Training dataset
    train_dataset = '../../processed_acdc_dataset/dataset/train_set'
    validation_dataset = '../../processed_acdc_dataset/dataset/validation_set'
    test_dataset = '../../processed_acdc_dataset/dataset/test_set'
    out_path_train = '../../processed_acdc_dataset/pickled/full_data'
    hdf5_out_path = '../../processed_acdc_dataset/hdf5_files'
    # Final Test dataset
    final_testing_dataset = '../../ACDC_DataSet/testing'
    out_path_test = '../../processed_acdc_dataset/pickled/final_test'

    # First perform stratified sampling
    group_patient_cases(complete_data_path, dest_path)
    generate_train_validate_test_set(group_path, dest_path)
    print("---Time taken to stratify the dataset %s seconds ---" % (time.time() - start_time))

    print('ROI->ED->ES train dataset')
    if not os.path.exists(out_path_train):
        os.makedirs(out_path_train)
    if not os.path.exists(out_path_test):
        os.makedirs(out_path_test)

    train_dataset = convert_nii_np(train_dataset, mode='train', roi_detect=True)
    save_data(train_dataset, 'train_set.pkl', out_path_train)
    print("---Processing training dataset %s seconds ---" % (time.time() - start_time))
    validation_dataset = convert_nii_np(validation_dataset, mode='train', roi_detect=True)
    save_data(validation_dataset, 'validation_set.pkl', out_path_train)
    print("---Processing validation dataset %s seconds ---" % (time.time() - start_time))
    test_dataset = convert_nii_np(test_dataset, mode='train', roi_detect=True)
    save_data(test_dataset, 'test_set.pkl', out_path_train)
    print("---Processing test dataset %s seconds ---" % (time.time() - start_time))

    print('ROI->ED->ES test dataset')
    final_test_dataset = convert_nii_np(final_testing_dataset, mode='test', roi_detect=True)
    save_data(final_test_dataset, 'final_testing_data.pkl', out_path_test)
    print("---Processing final testing dataset %s seconds ---" % (time.time() - start_time))

    # Generate 2D HDF5 files
    modes = ['train_set', 'validation_set', 'test_set']
    for mode in modes:
        if os.path.exists(os.path.join(hdf5_out_path, mode)):
            shutil.rmtree(os.path.join(hdf5_out_path, mode))
        os.makedirs(os.path.join(hdf5_out_path, mode))
        patient_data = load_pkl(os.path.join(out_path_train, mode + '.pkl'))
        for patient_id in tqdm(patient_data.keys()):
            # print (patient_id)
            _id = patient_id[-3:]
            n_slices = patient_data[patient_id]['ED_VOL'].shape[2]
            # print (n_slices)
            for slice in range(n_slices):
                # ED frames
                group = patient_data[patient_id]['Group']
                slice_str = '_%02d_' % slice
                roi_center = (patient_data[patient_id]['roi_center'][1], patient_data[patient_id]['roi_center'][0])
                hp = h5py.File(os.path.join(hdf5_out_path, mode, 'P_' + _id + '_ED' + slice_str + group + '.hdf5'), 'w')
                hp.create_dataset('image', data=patient_data[patient_id]['ED_VOL'][:, :, slice].T)
                hp.create_dataset('label', data=patient_data[patient_id]['ED_GT'][:, :, slice].T)
                hp.create_dataset('roi_center', data=roi_center)
                hp.create_dataset('roi_radii', data=patient_data[patient_id]['roi_radii'])
                hp.create_dataset('pixel_spacing', data=patient_data[patient_id]['header']['hdr'].get_zooms())
                hp.close()
                # ES frames
                hp = h5py.File(os.path.join(hdf5_out_path, mode, 'P_' + _id + '_ES' + slice_str + group + '.hdf5'), 'w')
                hp.create_dataset('image', data=patient_data[patient_id]['ES_VOL'][:, :, slice].T)
                hp.create_dataset('label', data=patient_data[patient_id]['ES_GT'][:, :, slice].T)
                hp.create_dataset('roi_center', data=roi_center)
                hp.create_dataset('roi_radii', data=patient_data[patient_id]['roi_radii'])
                hp.create_dataset('pixel_spacing', data=patient_data[patient_id]['header']['hdr'].get_zooms())
                hp.close()
    print("---Time taken to generate hdf5 files %s seconds ---" % (time.time() - start_time))