|
a |
|
b/src/postprocess.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import numpy as np |
|
|
5 |
import nibabel as nib |
|
|
6 |
import matplotlib.pyplot as plt |
|
|
7 |
from multiprocessing import Pool, cpu_count |
|
|
8 |
from scipy.ndimage.interpolation import zoom |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
def plot_middle(data, slice_no=None): |
|
|
12 |
if not slice_no: |
|
|
13 |
slice_no = data.shape[-1] // 2 |
|
|
14 |
plt.figure() |
|
|
15 |
plt.imshow(data[..., slice_no], cmap="gray") |
|
|
16 |
plt.show() |
|
|
17 |
return |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def create_dir(path): |
|
|
21 |
if not os.path.isdir(path): |
|
|
22 |
os.makedirs(path) |
|
|
23 |
return |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def load_nii(path): |
|
|
27 |
return nib.load(path).get_data() |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def save_nii(data, path): |
|
|
31 |
nib.save(nib.Nifti1Image(data, np.eye(4)), path) |
|
|
32 |
return |
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def brain_mask(data, mask): |
|
|
36 |
return np.multiply(data, mask) |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def resize(data, target_shape=[96, 112, 96]): |
|
|
40 |
factor = [float(t) / float(s) for t, s in zip(target_shape, data.shape)] |
|
|
41 |
resized = zoom(data, zoom=factor, order=1, prefilter=False) |
|
|
42 |
return resized |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def norm(data): |
|
|
46 |
# obj_idx = np.where(data > 0) |
|
|
47 |
# obj = data[obj_idx] |
|
|
48 |
# obj_mean, obj_std = np.mean(obj), np.std(obj) |
|
|
49 |
# obj = (obj - obj_mean) / obj_std |
|
|
50 |
# data[obj_idx] = obj |
|
|
51 |
data = data / float(np.max(data)) |
|
|
52 |
return data |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
def unwarp_postprocess(arg, **kwarg): |
|
|
56 |
return postprocess(*arg, **kwarg) |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
def postprocess(src_path, dst_path, temp_path=None, is_mask=False): |
|
|
60 |
print("Wroking on: ", src_path) |
|
|
61 |
try: |
|
|
62 |
data = load_nii(src_path) |
|
|
63 |
if is_mask: |
|
|
64 |
mask = load_nii(temp_path) |
|
|
65 |
data = brain_mask(data, mask) |
|
|
66 |
data = resize(data) |
|
|
67 |
# data = norm(data) |
|
|
68 |
save_nii(data, dst_path) |
|
|
69 |
except RuntimeError: |
|
|
70 |
print("\tFalid on: ", src_path) |
|
|
71 |
|
|
|
72 |
|
|
|
73 |
parent_dir = os.path.dirname(os.getcwd()) |
|
|
74 |
data_dir = os.path.join(parent_dir, "data") |
|
|
75 |
data_src_dir = os.path.join(data_dir, "ADNISegment") |
|
|
76 |
data_dst_dir = os.path.join(data_dir, "ADNISegmentPost") |
|
|
77 |
data_labels = ["AD", "NC"] |
|
|
78 |
create_dir(data_dst_dir) |
|
|
79 |
|
|
|
80 |
data_src_paths, data_dst_paths = [], [] |
|
|
81 |
for label in data_labels: |
|
|
82 |
src_label_dir = os.path.join(data_src_dir, label) |
|
|
83 |
dst_label_dir = os.path.join(data_dst_dir, label) |
|
|
84 |
create_dir(dst_label_dir) |
|
|
85 |
for subject in os.listdir(src_label_dir): |
|
|
86 |
data_src_paths.append(os.path.join(src_label_dir, subject)) |
|
|
87 |
data_dst_paths.append(os.path.join(dst_label_dir, subject)) |
|
|
88 |
|
|
|
89 |
temp_path = os.path.join(data_dir, "Template", "bianca_exclusion_mask.nii.gz") |
|
|
90 |
|
|
|
91 |
# Test |
|
|
92 |
# postprocess(data_src_paths[0], data_dst_paths[0], temp_path) |
|
|
93 |
|
|
|
94 |
# Multi-processing |
|
|
95 |
is_mask = False |
|
|
96 |
subj_num = len(data_src_paths) |
|
|
97 |
paras = zip(data_src_paths, data_dst_paths, |
|
|
98 |
[temp_path] * subj_num, [is_mask] * subj_num) |
|
|
99 |
pool = Pool(processes=cpu_count()) |
|
|
100 |
pool.map(unwarp_postprocess, paras) |