Diff of /pipeline/dwi_masking.py [000000] .. [e918fa]

Switch to unified view

a b/pipeline/dwi_masking.py
1
#!/usr/bin/env python
2
3
from __future__ import division
4
5
"""
6
pipeline.py
7
~~~~~~~~~~
8
01)  Accepts the diffusion image in *.nhdr,*.nrrd,*.nii.gz,*.nii format
9
02)  Checks if the Image axis is in the correct order for *.nhdr and *.nrrd file
10
03)  Extracts b0 Image
11
04)  Converts nhdr to nii.gz
12
05)  Applies rigid-body tranformation to standard MNI space using
13
06)  Normalize the Image by 99th percentile
14
07)  Predicts neural network brain mask across the 3 principal axes
15
08)  Performs multi-view aggregation
16
10)  Applies inverse tranformation
17
10)  Cleans holes
18
"""
19
20
# pylint: disable=invalid-name
21
import os
22
from os import path
23
import webbrowser
24
import multiprocessing as mp
25
import sys
26
from glob import glob
27
import subprocess
28
import argparse
29
import datetime
30
import pathlib
31
import nibabel as nib
32
import numpy as np
33
34
# Suppress tensor flow message
35
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
36
37
# Set CUDA_DEVICE_ORDER so the IDs assigned by CUDA match those from nvidia-smi
38
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
39
40
# Get the first available GPU
41
try:
42
    import GPUtil
43
44
    DEVICE_ID_LIST = [g.id for g in GPUtil.getGPUs()]
45
    DEVICE_ID = DEVICE_ID_LIST[0]
46
47
    CUDA_VISIBLE_DEVICES = os.getenv('CUDA_VISIBLE_DEVICES')
48
    if CUDA_VISIBLE_DEVICES:
49
        # prioritize external definition
50
        if int(CUDA_VISIBLE_DEVICES) in DEVICE_ID_LIST:
51
            pass
52
        else:
53
            # define it internally
54
            CUDA_VISIBLE_DEVICES = DEVICE_ID
55
    else:
56
        # define it internally
57
        CUDA_VISIBLE_DEVICES = DEVICE_ID
58
59
    os.environ["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES)
60
    # setting of CUDA_VISIBLE_DEVICES also masks out all other GPUs
61
62
    print("Use GPU device #", CUDA_VISIBLE_DEVICES)
63
64
except:
65
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
66
    print("GPU not available...")
67
68
69
import warnings
70
71
with warnings.catch_warnings():
72
    warnings.filterwarnings("ignore", category=FutureWarning)
73
    import tensorflow as tf
74
    try:
75
        from keras.models import model_from_json
76
        from keras import backend as K
77
        from keras.optimizers import Adam
78
    except ImportError:
79
        from tensorflow.keras.models import model_from_json
80
        from tensorflow.keras import backend as K
81
        from tensorflow.keras.optimizers import Adam
82
83
    # check version of tf and if 1.12 or less use tf.logging.set_verbosity(tf.logging.ERROR)
84
    if int(tf.__version__.split('.')[0]) <= 1 and int(tf.__version__.split('.')[1]) <= 12:
85
        tf.logging.set_verbosity(tf.logging.ERROR)
86
        # Configure for dynamic GPU memory usage
87
        config = tf.ConfigProto()
88
        config.gpu_options.allow_growth = True
89
        config.log_device_placement = False
90
        sess = tf.Session(config=config)
91
        K.set_session(sess)
92
    else:
93
        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
94
        gpus = tf.config.experimental.list_physical_devices('GPU')
95
        if gpus:
96
            try:
97
                # Currently, memory growth needs to be the same across GPUs
98
                for gpu in gpus:
99
                    tf.config.experimental.set_memory_growth(gpu, True)
100
            except RuntimeError as e:
101
                # Memory growth must be set before GPUs have been initialized
102
                print(e)
103
104
# suffixes
105
SUFFIX_NIFTI = "nii"
106
SUFFIX_NIFTI_GZ = "nii.gz"
107
SUFFIX_NRRD = "nrrd"
108
SUFFIX_NHDR = "nhdr"
109
SUFFIX_NPY = "npy"
110
SUFFIX_TXT = "txt"
111
output_mask = []
112
113
114
def predict_mask(input_file, trained_folder, view='default'):
115
    """
116
    Parameters
117
    ----------
118
    input_file : str
119
                 (single case filename which is stored in disk in *.nii.gz format) or 
120
                 (list of cases, all appended to 3d numpy array stored in disk in *.npy format)
121
    view       : str
122
                 Three principal axes ( Sagittal, Coronal and Axial )
123
    
124
    Returns
125
    -------
126
    output_file : str
127
                  returns the neural network predicted filename which is stored
128
                  in disk in 3d numpy array *.npy format
129
    """
130
    print("Loading " + view + " model from disk...")
131
    smooth = 1.
132
133
    def dice_coef(y_true, y_pred):
134
        y_true_f = K.flatten(y_true)
135
        y_pred_f = K.flatten(y_pred)
136
        intersection = K.sum(y_true_f * y_pred_f)
137
        return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
138
139
    # Negative dice to obtain region of interest (ROI-Branch loss) 
140
    def dice_coef_loss(y_true, y_pred):
141
        return -dice_coef(y_true, y_pred)
142
143
    # Positive dice to minimize overlap with region of interest (Complementary branch (CO) loss)
144
    def neg_dice_coef_loss(y_true, y_pred):
145
        return dice_coef(y_true, y_pred)
146
147
    # load json and create model
148
    json_file = open(trained_folder + '/CompNetBasicModel.json', 'r')
149
    loaded_model_json = json_file.read()
150
    json_file.close()
151
    loaded_model = model_from_json(loaded_model_json)
152
153
    # load weights into new model
154
    optimal_model = glob(trained_folder + '/weights-' + view + '-improvement-*.h5')[-1]
155
    loaded_model.load_weights(optimal_model)
156
157
    # check if tf2 or tf1, if tf 1 use lr instead of learning_rate
158
    if int(tf.__version__.split('.')[0]) <= 1:
159
        loaded_model.compile(optimizer=Adam(lr=1e-5),
160
                             loss={'final_op': dice_coef_loss,
161
                                   'xfinal_op': neg_dice_coef_loss,
162
                                   'res_1_final_op': 'mse'})
163
    else:
164
        loaded_model.compile(optimizer=Adam(learning_rate=1e-5),
165
                             loss={'final_op': dice_coef_loss,
166
                                   'xfinal_op': neg_dice_coef_loss,
167
                                   'res_1_final_op': 'mse'})
168
169
    case_name = path.basename(input_file)
170
    output_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-' + view + '-mask.npy'
171
    output_file = path.join(path.dirname(input_file), output_name)
172
173
    x_test = np.load(input_file)
174
    x_test = x_test.reshape(x_test.shape + (1,))
175
    predict_x = loaded_model.predict(x_test, verbose=1)
176
    SO = predict_x[0]  # Segmentation Output
177
    del predict_x
178
    np.save(output_file, SO)
179
    return output_file
180
181
182
def multi_view_fast(sagittal_SO, coronal_SO, axial_SO, input_file):
183
    x = np.load(sagittal_SO)
184
    y = np.load(coronal_SO)
185
    z = np.load(axial_SO)
186
187
    m, n = x.shape[::2]
188
    x = x.transpose(0, 3, 1, 2).reshape(m, -1, n)
189
190
    m, n = y.shape[::2]
191
    y = y.transpose(0, 3, 1, 2).reshape(m, -1, n)
192
193
    m, n = z.shape[::2]
194
    z = z.transpose(0, 3, 1, 2).reshape(m, -1, n)
195
196
    x = np.multiply(x, 0.1)
197
    y = np.multiply(y, 0.4)
198
    z = np.multiply(z, 0.5)
199
200
    print("Performing Muti View Aggregation...")
201
    XplusY = np.add(x, y)
202
    multi_view = np.add(XplusY, z)
203
    multi_view[multi_view > 0.45] = 1
204
    multi_view[multi_view <= 0.45] = 0
205
206
    case_name = path.basename(input_file)
207
    output_name = case_name[:len(case_name) - (len(SUFFIX_NHDR) + 1)] + '-multi-mask.npy'
208
    output_file = path.join(path.dirname(input_file), output_name)
209
210
    SO = multi_view.astype('float32')
211
    np.save(output_file, SO)
212
    return output_file
213
214
215
def normalize(b0_resampled, percentile):
216
    """
217
    Intensity based segmentation of MR images is hampered by radio frerquency field
218
    inhomogeneity causing intensity variation. The intensity range is typically
219
    scaled between the highest and lowest signal in the Image. Intensity values
220
    of the same tissue can vary between scans. The pixel value in images must be
221
    scaled prior to providing the images as input to CNN. The data is projected in to
222
    a predefined range [0,1]
223
224
    Parameters
225
    ---------
226
    b0_resampled : str
227
                   Accepts b0 resampled filename in *.nii.gz format
228
    Returns
229
    --------
230
    output_file : str
231
                  Normalized by 99th percentile filename which is stored in disk
232
    """
233
    print("Normalizing input data")
234
235
    input_file = b0_resampled
236
    case_name = path.basename(input_file)
237
    output_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-normalized.nii.gz'
238
    output_file = path.join(path.dirname(input_file), output_name)
239
    img = nib.load(b0_resampled)
240
    imgU16 = img.get_fdata().astype(np.float32)
241
    p = np.percentile(imgU16, percentile)
242
    data = imgU16 / p
243
    data[data > 1] = 1
244
    data[data < 0] = 0
245
    image_dwi = nib.Nifti1Image(data, img.affine, img.header)
246
    nib.save(image_dwi, output_file)
247
    return output_file
248
249
250
def save_nifti(fname, data, affine=None, hdr=None):
251
    hdr.set_data_dtype('int16')
252
    result_img = nib.Nifti1Image(data, affine, header=hdr)
253
    result_img.to_filename(fname)
254
255
256
def npy_to_nifti(b0_normalized_cases, cases_mask_arr, sub_name, view='default', reference='default', omat=None):
257
    """
258
    Parameters
259
    ---------
260
    b0_normalized_cases : str or list
261
                          str  (b0 normalized single filename which is in *.nii.gz format)
262
                          list (b0 normalized list of filenames which is in *.nii.gz format)
263
    case_mask_arr       : str or list
264
                          str  (single predicted mask filename which is in 3d numpy *.npy format)
265
                          list (list of predicted mask filenames which is in 3d numpy *.npy format)
266
    sub_name            : str or list
267
                          str  (single input case filename which is in *.nhdr format)
268
                          list (list of input case filename which is in *.nhdr format)
269
    view                : str
270
                          Three principal axes ( Sagittal, Coronal and Axial )
271
272
    reference           : str or list
273
                          str  (Linear-normalized case name which is in *.nii.gz format. 
274
                                This is the file before the rigid-body transformation step)
275
    Returns
276
    --------
277
    output_mask         : str or list
278
                          str  (single brain mask filename which is stored in disk in *.nhdr format)
279
                          list (list of brain mask for all cases which is stored in disk in *.nhdr format)
280
    """
281
    print("Converting file format...")
282
    global output_mask
283
    output_mask = []
284
    for i in range(0, len(b0_normalized_cases)):
285
        image_space = nib.load(b0_normalized_cases[i])
286
        predict = np.load(cases_mask_arr[i])
287
        predict[predict >= 0.5] = 1
288
        predict[predict < 0.5] = 0
289
        predict = predict.astype('int16')
290
        image_predict = nib.Nifti1Image(predict, image_space.affine, image_space.header)
291
        output_dir = path.dirname(sub_name[i])
292
        output_file = cases_mask_arr[i][:len(cases_mask_arr[i]) - len(SUFFIX_NPY)] + 'nii.gz'
293
        nib.save(image_predict, output_file)
294
295
        output_file_inverseMask = ANTS_inverse_transform(output_file, reference[i], omat[i])
296
        Ants_inverse_output_file = output_file_inverseMask
297
298
        case_name = path.basename(Ants_inverse_output_file)
299
        fill_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-filled.nii.gz'
300
        filled_file = path.join(output_dir, fill_name)
301
        fill_cmd = "ImageMath 3 " + filled_file + " FillHoles " + Ants_inverse_output_file
302
        process = subprocess.Popen(fill_cmd.split(), stdout=subprocess.PIPE)
303
        output, error = process.communicate()
304
305
        subject_name = path.basename(sub_name[i])
306
        if subject_name.endswith(SUFFIX_NIFTI_GZ):
307
            format = SUFFIX_NIFTI_GZ
308
        else:
309
            format = SUFFIX_NIFTI
310
311
        # Neural Network Predicted Mask
312
        CNN_predict_file = subject_name[:len(subject_name) - (len(format) + 1)] + '-' + view + '_originalMask.nii.gz'
313
        CNN_output_file = path.join(output_dir, CNN_predict_file)
314
        bashCommand = 'cp ' + filled_file + " " + CNN_output_file
315
        process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE)
316
        output, error = process.communicate()
317
318
        output_filter_file = subject_name[:len(subject_name) - (len(format) + 1)] + '-' + view + '_FilteredMask.nii.gz'
319
        output_mask_filtered = path.join(output_dir, output_filter_file)
320
321
        if args.filter:
322
            print('Cleaning up ', CNN_output_file)
323
324
            if args.filter == 'mrtrix':
325
                mask_filter = "maskfilter -force " + CNN_output_file + " -scale 2 clean " + output_mask_filtered
326
327
            elif args.filter == 'scipy':
328
                mask_filter = path.join(path.dirname(__file__), '../src/maskfilter.py') + f' {CNN_output_file} 2 {output_mask_filtered}'
329
330
            process = subprocess.Popen(mask_filter.split(), stdout=subprocess.PIPE)
331
            output, error = process.communicate()
332
333
        else:
334
            output_mask_filtered = CNN_output_file
335
336
        print(output_mask_filtered)
337
        img = nib.load(output_mask_filtered)
338
        data_dwi = nib.load(sub_name[i])
339
        imgU16 = img.get_fdata().astype(np.uint8)
340
341
        brain_mask_file = subject_name[:len(subject_name) - (len(format) + 1)] + '-' + view + '_BrainMask.nii.gz'
342
        brain_mask_final = path.join(output_dir, brain_mask_file)
343
344
        save_nifti(brain_mask_final, imgU16, affine=data_dwi.affine, hdr=data_dwi.header)
345
        output_mask.append(brain_mask_final)
346
347
    return output_mask
348
349
350
def clear(directory):
351
    print("Cleaning files ...")
352
353
    bin_a = 'cases_' + str(os.getpid()) + '_binary_a'
354
    bin_s = 'cases_' + str(os.getpid()) + '_binary_s'
355
    bin_c = 'cases_' + str(os.getpid()) + '_binary_c'
356
357
    for filename in os.listdir(directory):
358
        if filename.startswith('Comp') | filename.endswith(SUFFIX_NPY) | \
359
                filename.endswith('_SO.nii.gz') | filename.endswith('downsampled.nii.gz') | \
360
                filename.endswith('-thresholded.nii.gz') | filename.endswith('-inverse.mat') | \
361
                filename.endswith('-Warped.nii.gz') | filename.endswith('-0GenericAffine.mat') | \
362
                filename.endswith('_affinedMask.nii.gz') | filename.endswith('_originalMask.nii.gz') | \
363
                filename.endswith('multi-mask.nii.gz') | filename.endswith('-mask-inverse.nii.gz') | \
364
                filename.endswith('-InverseWarped.nii.gz') | filename.endswith('-FilteredMask.nii.gz') | \
365
                filename.endswith(bin_a) | filename.endswith(bin_c) | filename.endswith(bin_s) | \
366
                filename.endswith('_FilteredMask.nii.gz') | filename.endswith('-normalized.nii.gz') | filename.endswith('-filled.nii.gz'):
367
            os.unlink(directory + '/' + filename)
368
369
370
def split(cases_file, case_arr, view='default'):
371
    """
372
    Parameters
373
    ---------
374
    cases_file : str
375
                 Accepts a filename which is in 3d numpy array format stored in disk
376
    split_dim  : list
377
                 Contains the "x" dim for all the cases
378
    case_arr   : list
379
                 Contain filename for all the input cases
380
    Returns
381
    --------
382
    predict_mask : list
383
                   Contains the predicted mask filename of all the cases which is stored in disk in *.npy format
384
    """
385
386
    count = 0
387
    start = 0
388
    end = start + 256
389
    SO = np.load(cases_file)
390
391
    predict_mask = []
392
    for i in range(0, len(case_arr)):
393
        end = start + 256
394
        casex = SO[start:end, :, :]
395
        if view == 'coronal':
396
            casex = np.swapaxes(casex, 0, 1)
397
        elif view == 'axial':
398
            casex = np.swapaxes(casex, 0, 2)
399
        input_file = str(case_arr[i])
400
        output_file = input_file[:len(input_file) - (len(SUFFIX_NHDR) + 1)] + '-' + view + '_SO.npy'
401
        predict_mask.append(output_file)
402
        np.save(output_file, casex)
403
        start = end
404
        count += 1
405
406
    return predict_mask
407
408
409
def ANTS_rigid_body_trans(b0_nii, reference=None):
410
    print("Performing ants rigid body transformation...")
411
    input_file = b0_nii
412
    case_name = path.basename(input_file)
413
    output_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-'
414
    output_file = path.join(path.dirname(input_file), output_name)
415
416
    trans_matrix = "antsRegistrationSyNQuick.sh -d 3 -f " + reference + " -m " + input_file + " -t r -o " + output_file
417
    output1 = subprocess.check_output(trans_matrix, shell=True)
418
419
    omat_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-0GenericAffine.mat'
420
    omat_file = path.join(path.dirname(input_file), omat_name)
421
422
    output_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-Warped.nii.gz'
423
    transformed_file = path.join(path.dirname(input_file), output_name)
424
425
    return (transformed_file, omat_file)
426
427
428
def ANTS_inverse_transform(predicted_mask, reference, omat='default'):
429
430
    print("Performing ants inverse transform...")
431
    input_file = predicted_mask
432
    case_name = path.basename(input_file)
433
    output_name = case_name[:len(case_name) - (len(SUFFIX_NIFTI_GZ) + 1)] + '-inverse.nii.gz'
434
    output_file = path.join(path.dirname(input_file), output_name)
435
436
    # reference is the original b0 volume
437
    apply_inverse_trans = "antsApplyTransforms -d 3 -i " + predicted_mask + " -r " + reference + " -o " \
438
                          + output_file + " --transform [" + omat + ",1]"
439
440
    output2 = subprocess.check_output(apply_inverse_trans, shell=True)
441
    return output_file
442
443
444
def str2bool(v):
445
    if isinstance(v, bool):
446
        return v
447
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
448
        return True
449
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
450
        return False
451
    else:
452
        raise argparse.ArgumentTypeError('Boolean value expected...')
453
454
455
def list_masks(mask_list, view='default'):
456
    for i in range(0, len(mask_list)):
457
        print(view + " Mask file = ", mask_list[i])
458
459
460
def pre_process(input_file, b0_threshold=50.):
461
    from conversion import nifti_write, read_bvals
462
463
    if path.isfile(input_file):
464
465
        # convert NRRD/NHDR to NIFIT as the first step
466
        # extract bse.py from just NIFTI later
467
        if input_file.endswith(SUFFIX_NRRD) | input_file.endswith(SUFFIX_NHDR):
468
            inPrefix = input_file.split('.')[0]
469
            nifti_write(input_file)
470
            input_file = inPrefix + '.nii.gz'
471
472
        inPrefix = input_file.split('.nii')[0]
473
        b0_nii = path.join(inPrefix + '_bse.nii.gz')
474
475
        dwi = nib.load(input_file)
476
477
        if len(dwi.shape) > 3:
478
            print("Extracting b0 volume...")
479
            bvals = np.array(read_bvals(input_file.split('.nii')[0] + '.bval'))
480
            where_b0 = np.where(bvals <= b0_threshold)[0]
481
            b0 = dwi.get_fdata()[..., where_b0].mean(-1)
482
        else:
483
            print("Loading b0 volume...")
484
            b0 = dwi.get_fdata()
485
486
        np.nan_to_num(b0).clip(min=0., out=b0)
487
        nib.Nifti1Image(b0, affine=dwi.affine, header=dwi.header).to_filename(b0_nii)
488
489
        return b0_nii
490
491
    else:
492
        print("File not found ", input_file)
493
        sys.exit(1)
494
495
496
def remove_string(input_file, output_file, string):
497
    infile = input_file
498
    outfile = output_file
499
    delete_list = [string]
500
    fin = open(infile)
501
    fout = open(outfile, "w+")
502
    for line in fin:
503
        for word in delete_list:
504
            line = line.replace(word, "")
505
        fout.write(line)
506
    fin.close()
507
    fout.close()
508
509
510
def quality_control(mask_list, target_list, tmp_path, view='default'):
511
    '''The slicesdir command takes the list of images and creates a simple web-page containing snapshots for each of the images.
512
    Once it has finished running it tells you the name of the web page to open in your web browser, to view the snapshots.
513
    '''
514
515
    slices = " "
516
    for i in range(0, len(mask_list)):
517
        str1 = target_list[i]
518
        str2 = mask_list[i]
519
        slices += path.basename(str1) + " " + path.basename(str2) + " "
520
521
    final = "slicesdir -o" + slices
522
    dir_bak = os.getcwd()
523
    os.chdir(tmp_path)
524
525
    process = subprocess.Popen(final, shell=True)
526
    process.wait()
527
    os.chdir(dir_bak)
528
529
    mask_folder = os.path.join(tmp_path, 'slicesdir')
530
    mask_newfolder = os.path.join(tmp_path, 'slicesdir_' + view)
531
    if os.path.exists(mask_newfolder):
532
        process = subprocess.Popen('rm -rf ' + mask_newfolder, shell=True)
533
        process.wait()
534
535
    process = subprocess.Popen('mv ' + mask_folder + " " + mask_newfolder, shell=True)
536
    process.wait()
537
538
539
if __name__ == '__main__':
540
541
    start_total_time = datetime.datetime.now()
542
    
543
    parser = argparse.ArgumentParser()
544
545
    parser.add_argument('-i', action='store', dest='dwi', type=str,
546
                        help="txt file containing list of /path/to/dwi or /path/to/b0, one path in each line")
547
548
    parser.add_argument('-f', action='store', dest='model_folder', type=str,
549
                        help="folder containing the trained models")
550
551
    parser.add_argument("-a", type=str2bool, dest='Axial', nargs='?',
552
                        const=True, default=False,
553
                        help="advanced option to generate multiview and axial Mask (yes/true/y/1)")
554
555
    parser.add_argument("-c", type=str2bool, dest='Coronal', nargs='?',
556
                        const=True, default=False,
557
                        help="advanced option to generate multiview and coronal Mask (yes/true/y/1)")
558
559
    parser.add_argument("-s", type=str2bool, dest='Sagittal', nargs='?',
560
                        const=True, default=False,
561
                        help="advanced option to generate multiview and sagittal Mask (yes/true/y/1)")
562
563
    parser.add_argument("-qc", type=str2bool, dest='snap', nargs='?',
564
                        const=True, default=False,
565
                        help="advanced option to take snapshots and open them in your web browser (yes/true/y/1)")
566
567
    parser.add_argument('-p', type=int, dest='percentile', default=99, help='''percentile of image
568
intensity value to be used as a threshold for normalizing a b0 image to [0,1]''')
569
570
    parser.add_argument('-nproc', type=int, dest='nproc', default=1, help='number of processes to use')
571
572
    parser.add_argument('-filter', choices=['scipy', 'mrtrix'], help='''perform morphological operation on the 
573
CNN generated mask to clean up holes and islands, can be done through a provided script (scipy) 
574
or MRtrix3 maskfilter (mrtrix)''')
575
576
    try:
577
        args = parser.parse_args()
578
        if len(sys.argv) == 1:
579
            parser.print_help()
580
            parser.error('too few arguments')
581
            sys.exit(0)
582
    except SystemExit:
583
        sys.exit(0)
584
585
    if args.dwi:
586
        f = pathlib.Path(args.dwi)
587
        if f.exists():
588
            print("File exist")
589
            filename = args.dwi
590
        else:
591
            print("File not found")
592
            sys.exit(1)
593
594
        # Input caselist.txt
595
        if filename.endswith(SUFFIX_TXT):
596
            with open(filename) as f:
597
                case_arr = f.read().splitlines()
598
599
            TXT_file = path.basename(filename)
600
            unique = TXT_file[:len(TXT_file) - (len(SUFFIX_TXT) + 1)]
601
            storage = path.dirname(case_arr[0])
602
            tmp_path = storage + '/'
603
            if not args.model_folder:
604
                trained_model_folder = path.abspath(path.dirname(__file__)+'/../model_folder')
605
            else:
606
                trained_model_folder = args.model_folder
607
            reference = trained_model_folder + '/IITmean_b0_256.nii.gz'
608
609
            binary_file_s = storage + '/' + unique + '_' + str(os.getpid()) + '_binary_s'
610
            binary_file_c = storage + '/' + unique + '_' + str(os.getpid()) + '_binary_c'
611
            binary_file_a = storage + '/' + unique + '_' + str(os.getpid()) + '_binary_a'
612
613
            f_handle_s = open(binary_file_s, 'wb')
614
            f_handle_c = open(binary_file_c, 'wb')
615
            f_handle_a = open(binary_file_a, 'wb')
616
617
            x_dim = 0
618
            y_dim = 256
619
            z_dim = 256
620
            transformed_cases = []
621
622
            if args.nproc==1:
623
624
                target_list=[]
625
                for case in case_arr:
626
                    target_list.append(pre_process(case))
627
628
                result=[]
629
                for target in target_list:
630
                    result.append(ANTS_rigid_body_trans(target,reference))
631
632
                data_n=[]
633
                for transformed_case, _ in result:
634
                    data_n.append(normalize(transformed_case, args.percentile))
635
636
            else:
637
                with mp.Pool(processes=args.nproc) as pool:
638
                    res=[]
639
                    for case in case_arr:
640
                        res.append(pool.apply_async(pre_process, (case,)))
641
642
                    target_list=[r.get() for r in res]
643
                    pool.close()
644
                    pool.join()
645
646
647
                with mp.Pool(processes=args.nproc) as pool:
648
                    res=[]
649
                    for target in target_list:
650
                        res.append(pool.apply_async(ANTS_rigid_body_trans, (target, reference,)))
651
652
                    result=[r.get() for r in res]
653
                    pool.close()
654
                    pool.join()
655
656
657
                with mp.Pool(processes=args.nproc) as pool:
658
                    res=[]
659
                    for transformed_case, _ in result:
660
                        res.append(pool.apply_async(normalize, (transformed_case, args.percentile,)))
661
662
                    data_n=[r.get() for r in res]
663
                    pool.close()
664
                    pool.join()
665
666
667
            count = 0
668
            for b0_nifti in data_n:
669
                img = nib.load(b0_nifti)
670
                # sagittal view
671
                imgU16_sagittal = img.get_fdata().astype(np.float32)
672
                # coronal view
673
                imgU16_coronal = np.swapaxes(imgU16_sagittal, 0, 1)
674
                # axial view
675
                imgU16_axial = np.swapaxes(imgU16_sagittal, 0, 2)
676
677
                imgU16_sagittal.tofile(f_handle_s)
678
                imgU16_coronal.tofile(f_handle_c)
679
                imgU16_axial.tofile(f_handle_a)
680
                count += 1
681
                print("Case completed = ", count)
682
683
            f_handle_s.close()
684
            f_handle_c.close()
685
            f_handle_a.close()
686
687
            print("Merging npy files...")
688
            cases_file_s = storage + '/' + unique + '_' + str(os.getpid()) + '-casefile-sagittal.npy'
689
            cases_file_c = storage + '/' + unique + '_' + str(os.getpid()) + '-casefile-coronal.npy'
690
            cases_file_a = storage + '/' + unique + '_' + str(os.getpid()) + '-casefile-axial.npy'
691
692
            merged_dwi_list = []
693
            merged_dwi_list.append(cases_file_s)
694
            merged_dwi_list.append(cases_file_c)
695
            merged_dwi_list.append(cases_file_a)
696
697
            merge_s = np.memmap(binary_file_s, dtype=np.float32, mode='r+', shape=(256 * len(target_list), y_dim, z_dim))
698
            merge_c = np.memmap(binary_file_c, dtype=np.float32, mode='r+', shape=(256 * len(target_list), y_dim, z_dim))
699
            merge_a = np.memmap(binary_file_a, dtype=np.float32, mode='r+', shape=(256 * len(target_list), y_dim, z_dim))
700
701
            print("Saving data to disk...")
702
            np.save(cases_file_s, merge_s)
703
            np.save(cases_file_c, merge_c)
704
            np.save(cases_file_a, merge_a)
705
706
            normalized_file = storage + "/norm_cases_" + str(os.getpid()) + ".txt"
707
            registered_file = storage + "/ants_cases_" + str(os.getpid()) + ".txt"
708
            mat_file = storage + "/mat_cases_" + str(os.getpid()) + ".txt"
709
            target_file = storage + "/target_cases_" + str(os.getpid()) + ".txt"
710
711
            with open(normalized_file, "w") as norm_dwi:
712
                for item in data_n:
713
                    norm_dwi.write(item + "\n")
714
715
            remove_string(normalized_file, registered_file, "-normalized")
716
            remove_string(registered_file, target_file, "-Warped")
717
718
            with open(target_file) as f:
719
                newText = f.read().replace('.nii.gz', '-0GenericAffine.mat')
720
721
            with open(mat_file, "w") as f:
722
                f.write(newText)
723
724
            end_preprocessing_time = datetime.datetime.now()
725
            total_preprocessing_time = end_preprocessing_time - start_total_time
726
            print("Pre-Processing Time Taken : ", round(int(total_preprocessing_time.seconds) / 60, 2), " min")
727
728
            # DWI Deep Learning Segmentation
729
            dwi_mask_sagittal = predict_mask(cases_file_s, trained_model_folder, view='sagittal')
730
            dwi_mask_coronal = predict_mask(cases_file_c, trained_model_folder, view='coronal')
731
            dwi_mask_axial = predict_mask(cases_file_a, trained_model_folder, view='axial')
732
733
            end_masking_time = datetime.datetime.now()
734
            total_masking_time = end_masking_time - start_total_time - total_preprocessing_time
735
            print("Masking Time Taken : ", round(int(total_masking_time.seconds) / 60, 2), " min")
736
737
            transformed_file = registered_file
738
            omat_file = mat_file
739
740
            transformed_cases = [line.rstrip('\n') for line in open(transformed_file)]
741
            target_list = [line.rstrip('\n') for line in open(target_file)]
742
            omat_list = [line.rstrip('\n') for line in open(omat_file)]
743
744
            # Post Processing
745
            print("Splitting files....")
746
            cases_mask_sagittal = split(dwi_mask_sagittal, target_list, view='sagittal')
747
            cases_mask_coronal = split(dwi_mask_coronal, target_list, view='coronal')
748
            cases_mask_axial = split(dwi_mask_axial, target_list, view='axial')
749
750
            multi_mask = []
751
            for i in range(0, len(cases_mask_sagittal)):
752
                sagittal_SO = cases_mask_sagittal[i]
753
                coronal_SO = cases_mask_coronal[i]
754
                axial_SO = cases_mask_axial[i]
755
756
                input_file = target_list[i]
757
758
                multi_view_mask = multi_view_fast(sagittal_SO,
759
                                                  coronal_SO,
760
                                                  axial_SO,
761
                                                  input_file)
762
763
                brain_mask_multi = npy_to_nifti(list(transformed_cases[i].split()),
764
                                                list(multi_view_mask.split()),
765
                                                list(target_list[i].split()),
766
                                                view='multi',
767
                                                reference=list(target_list[i].split()),
768
                                                omat=list(omat_list[i].split()))
769
770
                print("Mask file : ", brain_mask_multi)
771
                multi_mask.append(brain_mask_multi[0])
772
            if args.snap:
773
                quality_control(multi_mask, target_list, tmp_path, view='multi')
774
775
            if args.Sagittal:
776
                omat = omat_list
777
            else:
778
                omat = None
779
780
            if args.Sagittal:
781
                sagittal_mask = npy_to_nifti(transformed_cases,
782
                                             cases_mask_sagittal,
783
                                             target_list,
784
                                             view='sagittal',
785
                                             reference=target_list,
786
                                             omat=omat)
787
                list_masks(sagittal_mask, view='sagittal')
788
                if args.snap:
789
                    quality_control(sagittal_mask, target_list, tmp_path, view='sagittal')
790
791
            if args.Coronal:
792
                omat = omat_list
793
            else:
794
                omat = None
795
796
            if args.Coronal:
797
                coronal_mask = npy_to_nifti(transformed_cases,
798
                                            cases_mask_coronal,
799
                                            target_list,
800
                                            view='coronal',
801
                                            reference=target_list,
802
                                            omat=omat)
803
                list_masks(coronal_mask, view='coronal')
804
                if args.snap:
805
                    quality_control(coronal_mask, target_list, tmp_path, view='coronal')
806
807
            if args.Axial:
808
                omat = omat_list
809
            else:
810
                omat = None
811
812
            if args.Axial:
813
                axial_mask = npy_to_nifti(transformed_cases,
814
                                          cases_mask_axial,
815
                                          target_list,
816
                                          view='axial',
817
                                          reference=target_list,
818
                                          omat=omat)
819
                list_masks(axial_mask, view='axial')
820
                if args.snap:
821
                    quality_control(axial_mask, target_list, tmp_path, view='axial')
822
823
            for i in range(0, len(cases_mask_sagittal)):
824
                clear(path.dirname(cases_mask_sagittal[i]))
825
826
            if args.snap:
827
                webbrowser.open(path.join(tmp_path, 'slicesdir_multi/index.html'))
828
                if args.Sagittal:
829
                    webbrowser.open(path.join(tmp_path, 'slicesdir_sagittal/index.html'))
830
                if args.Coronal:
831
                    webbrowser.open(path.join(tmp_path, 'slicesdir_coronal/index.html'))
832
                if args.Axial:
833
                    webbrowser.open(path.join(tmp_path, 'slicesdir_axial/index.html'))
834
835
        end_total_time = datetime.datetime.now()
836
        total_t = end_total_time - start_total_time
837
        print("Total Time Taken : ", round(int(total_t.seconds) / 60, 2), " min")