Diff of /src/postprocess.py [000000] .. [602ab8]

Switch to unified view

a b/src/postprocess.py
1
from __future__ import print_function
2
3
import os
4
import numpy as np
5
import nibabel as nib
6
import matplotlib.pyplot as plt
7
from multiprocessing import Pool, cpu_count
8
from scipy.ndimage.interpolation import zoom
9
10
11
def plot_middle(data, slice_no=None):
12
    if not slice_no:
13
        slice_no = data.shape[-1] // 2
14
    plt.figure()
15
    plt.imshow(data[..., slice_no], cmap="gray")
16
    plt.show()
17
    return
18
19
20
def create_dir(path):
21
    if not os.path.isdir(path):
22
        os.makedirs(path)
23
    return
24
25
26
def load_nii(path):
27
    return nib.load(path).get_data()
28
29
30
def save_nii(data, path):
31
    nib.save(nib.Nifti1Image(data, np.eye(4)), path)
32
    return
33
34
35
def brain_mask(data, mask):
36
    return np.multiply(data, mask)
37
38
39
def resize(data, target_shape=[96, 112, 96]):
40
    factor = [float(t) / float(s) for t, s in zip(target_shape, data.shape)]
41
    resized = zoom(data, zoom=factor, order=1, prefilter=False)
42
    return resized
43
44
45
def norm(data):
46
    # obj_idx = np.where(data > 0)
47
    # obj = data[obj_idx]
48
    # obj_mean, obj_std = np.mean(obj), np.std(obj)
49
    # obj = (obj - obj_mean) / obj_std
50
    # data[obj_idx] = obj
51
    data = data / float(np.max(data))
52
    return data
53
54
55
def unwarp_postprocess(arg, **kwarg):
56
    return postprocess(*arg, **kwarg)
57
58
59
def postprocess(src_path, dst_path, temp_path=None, is_mask=False):
60
    print("Wroking on: ", src_path)
61
    try:
62
        data = load_nii(src_path)
63
        if is_mask:
64
            mask = load_nii(temp_path)
65
            data = brain_mask(data, mask)
66
        data = resize(data)
67
        # data = norm(data)
68
        save_nii(data, dst_path)
69
    except RuntimeError:
70
        print("\tFalid on: ", src_path)
71
72
73
parent_dir = os.path.dirname(os.getcwd())
74
data_dir = os.path.join(parent_dir, "data")
75
data_src_dir = os.path.join(data_dir, "ADNISegment")
76
data_dst_dir = os.path.join(data_dir, "ADNISegmentPost")
77
data_labels = ["AD", "NC"]
78
create_dir(data_dst_dir)
79
80
data_src_paths, data_dst_paths = [], []
81
for label in data_labels:
82
    src_label_dir = os.path.join(data_src_dir, label)
83
    dst_label_dir = os.path.join(data_dst_dir, label)
84
    create_dir(dst_label_dir)
85
    for subject in os.listdir(src_label_dir):
86
        data_src_paths.append(os.path.join(src_label_dir, subject))
87
        data_dst_paths.append(os.path.join(dst_label_dir, subject))
88
89
temp_path = os.path.join(data_dir, "Template", "bianca_exclusion_mask.nii.gz")
90
91
# Test
92
# postprocess(data_src_paths[0], data_dst_paths[0], temp_path)
93
94
# Multi-processing
95
is_mask = False
96
subj_num = len(data_src_paths)
97
paras = zip(data_src_paths, data_dst_paths,
98
            [temp_path] * subj_num, [is_mask] * subj_num)
99
pool = Pool(processes=cpu_count())
100
pool.map(unwarp_postprocess, paras)