Diff of /src/segment.py [000000] .. [602ab8]

Switch to unified view

a b/src/segment.py
1
from __future__ import print_function
2
3
import os
4
import numpy as np
5
import nibabel as nib
6
import skfuzzy as fuzz
7
from sklearn.cluster import KMeans
8
from multiprocessing import Pool, cpu_count
9
10
11
def create_dir(path):
    """Create directory ``path`` (with any missing parents) if needed.

    The directory is created first and an "already exists" failure is
    tolerated afterwards.  A check-then-create sequence can crash when two
    processes try to create the same directory at the same time; this form
    cannot.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If creation failed and ``path`` is still not a directory
            afterwards (for example a regular file is in the way, or the
            parent is not writable).
    """
    try:
        os.makedirs(path)
    except OSError:
        # Swallow the error only when the directory really exists now;
        # every other failure is a genuine problem the caller must see.
        if not os.path.isdir(path):
            raise
15
16
17
def load_nii(path):
    """Load an image with nibabel and return its voxel data and affine.

    Args:
        path: Path of the image file handed to ``nib.load``.

    Returns:
        A ``(data, affine)`` tuple: the voxel array and the image's affine
        matrix.
    """
    nii = nib.load(path)
    # ``get_data()`` / ``get_affine()`` are deprecated and no longer exist in
    # recent nibabel releases.  ``np.asanyarray(img.dataobj)`` is the
    # documented drop-in for ``get_data()`` (it keeps the stored dtype rather
    # than forcing floats), and ``img.affine`` replaces ``get_affine()``.
    return np.asanyarray(nii.dataobj), nii.affine
20
21
22
def save_nii(data, path, affine):
    """Wrap ``data`` and ``affine`` in a ``Nifti1Image`` and save it.

    Args:
        data: Voxel array to store.
        path: Destination file path passed to ``nib.save``.
        affine: Affine matrix stored alongside the voxel data.
    """
    image = nib.Nifti1Image(data, affine)
    nib.save(image, path)
25
26
27
def extract_features(data):
    """Collect intensity and coordinates of every positive voxel.

    Args:
        data: 3-D array of voxel intensities.

    Returns:
        A 2-D array with one row ``[intensity, x, y, z]`` per voxel whose
        value is greater than zero, in the order ``np.where`` reports them.
        When no voxel is positive the result has shape ``(0, 4)``.
    """
    x_idx, y_idx, z_idx = np.where(data > 0)
    # Build all rows in one vectorized step instead of appending to a Python
    # list voxel by voxel; a volume can hold millions of positive voxels.
    return np.column_stack((data[x_idx, y_idx, z_idx], x_idx, y_idx, z_idx))
33
34
35
def kmeans_cluster(data, n_clusters):
    """Cluster the positive voxels of ``data`` by intensity with K-Means.

    Args:
        data: 3-D array of voxel intensities.
        n_clusters: Number of intensity clusters to fit.

    Returns:
        A float array shaped like ``data``.  Voxels that were not clustered
        (value <= 0) stay 0; clustered voxels hold their K-Means label
        shifted by one, i.e. a value in ``1..n_clusters``.
    """
    features = extract_features(data)
    # Only the intensity column drives the clustering; scikit-learn expects
    # samples as rows, hence the (n_voxels, 1) shape.
    intensities = features[..., 0].reshape((-1, 1))

    # ``precompute_distances`` and ``n_jobs`` were removed from KMeans and
    # now raise TypeError, so they are no longer passed.  ``n_init`` is
    # pinned to the former default of 10 so the result does not depend on
    # which default the installed scikit-learn version uses.
    kmeans_model = KMeans(n_clusters=n_clusters, init="k-means++",
                          n_init=10, verbose=0, random_state=7,
                          max_iter=1000, tol=1e-6).fit(intensities)

    # Columns 1..3 are the voxel coordinates; write all labels back in one
    # fancy-indexed assignment instead of a per-voxel Python loop.
    coords = features[:, 1:4].astype(int)
    labels = np.zeros(data.shape)
    labels[coords[:, 0], coords[:, 1], coords[:, 2]] = kmeans_model.labels_ + 1

    return labels
48
49
50
def fuzzy_cmeans_cluster(data, n_clusters):
    """Cluster the positive voxels of ``data`` by intensity with fuzzy C-Means.

    Each voxel is assigned to the cluster with the highest membership value.

    Args:
        data: 3-D array of voxel intensities.
        n_clusters: Number of intensity clusters to fit.

    Returns:
        A float array shaped like ``data``.  Voxels that were not clustered
        (value <= 0) stay 0; clustered voxels hold the index of their
        strongest cluster shifted by one, i.e. a value in ``1..n_clusters``.
    """
    features = extract_features(data)
    # Intensities are laid out as a single row: (1, n_voxels).
    intensities = features[..., 0].reshape((1, -1))

    # Of the seven values ``cmeans`` returns, only the second one (the
    # membership matrix) is used here, so the rest are not bound to names.
    membership = fuzz.cluster.cmeans(
        intensities, n_clusters, 2, error=1e-6,
        maxiter=1000, init=None, seed=7)[1]
    hard_labels = np.argmax(membership, axis=0)

    # Columns 1..3 are the voxel coordinates; write all labels back in one
    # fancy-indexed assignment instead of a per-voxel Python loop.
    coords = features[:, 1:4].astype(int)
    labels = np.zeros(data.shape)
    labels[coords[:, 0], coords[:, 1], coords[:, 2]] = hard_labels + 1

    return labels
64
65
66
def get_target_label(labels, data):
    """Return the cluster label whose mean intensity is the middle one.

    With the three clusters used by this script, the cluster of median mean
    intensity is treated as grey matter (the brightest would be WM and the
    darkest CSF, per the original author's notes).

    Args:
        labels: Array of cluster labels; 0 marks unclustered voxels.
        data: Intensity array with the same shape as ``labels``.

    Returns:
        The label value (as ``int``) of the selected cluster.  For an even
        number of clusters the lower of the two middle clusters is chosen.

    Raises:
        ValueError: If ``labels`` contains no non-zero label.
    """
    labels_set = np.unique(labels)
    # Drop the background by value, not by position: slicing off the first
    # entry would discard a real cluster whenever no voxel is labelled 0.
    cluster_labels = labels_set[labels_set != 0]
    if cluster_labels.size == 0:
        raise ValueError("labels contains no clustered voxels")

    mean_intensities = [float(np.mean(data[labels == label]))
                        for label in cluster_labels]

    # Pick an actual element of the list.  ``np.median`` averages the two
    # middle values for an even count, and that average is generally not in
    # the list, so looking it up with ``index`` would raise.
    middle = (len(mean_intensities) - 1) // 2
    target_intensity = sorted(mean_intensities)[middle]

    # Return the real label value instead of ``position + 1`` so the result
    # stays correct when label values are not exactly 1..k.
    position = mean_intensities.index(target_intensity)
    return int(cluster_labels[position])
77
78
79
def unwarp_segment(arg, **kwarg):
    """Call ``segment`` with the items of ``arg`` as positional arguments.

    ``Pool.map`` hands each worker a single object, so the per-subject
    arguments travel as one sequence and are unpacked here.

    Args:
        arg: Iterable of positional arguments for ``segment``.
        **kwarg: Keyword arguments forwarded to ``segment`` unchanged.

    Returns:
        Whatever ``segment`` returns.
    """
    positional = tuple(arg)
    return segment(*positional, **kwarg)
81
82
83
def segment(src_path, dst_path, labels_path=None, method="km"):
    """Segment one volume and save the target-tissue-weighted image.

    The volume is clustered into three intensity classes.  The class picked
    by ``get_target_label`` keeps its full intensity; every other voxel is
    scaled down by a fixed factor.

    Args:
        src_path: Path of the input image.
        dst_path: Path the weighted image is written to.
        labels_path: Optional path for the label volume.  When ``None`` the
            label volume is not saved.
        method: ``"km"`` for K-Means or ``"fcm"`` for fuzzy C-Means.

    Raises:
        ValueError: If ``method`` is neither ``"km"`` nor ``"fcm"``.
    """
    # Reject an unknown method up front; otherwise no branch would assign
    # the label volume and the failure would surface as an unbound local.
    if method not in ("km", "fcm"):
        raise ValueError("method must be 'km' or 'fcm', got %r" % (method,))

    print("Segment on: ", src_path)
    try:
        data, affine = load_nii(src_path)
        n_clusters = 3

        if method == "km":
            # Method 1 - KMeans
            labels = kmeans_cluster(data, n_clusters)
        else:
            # Method 2 - Fuzzy CMeans
            labels = fuzzy_cmeans_cluster(data, n_clusters)

        target = get_target_label(labels, data)

        # Weight 1 on the target cluster, 0.333 everywhere else, so the
        # other voxels are attenuated rather than removed.
        other_weight = 0.333
        gm_mask = np.where(labels == target, 1., other_weight)
        gm_mask = gm_mask.astype(np.float32)
        gm = np.multiply(data.astype(np.float32), gm_mask)

        # ``labels_path`` defaults to None; only write the label volume
        # when the caller actually asked for it.
        if labels_path is not None:
            save_nii(labels, labels_path, affine)
        save_nii(gm, dst_path, affine)
    except RuntimeError:
        print("\tFailed on: ", src_path)

    return
108
109
110
# Batch configuration.  Paths are resolved relative to the parent of the
# current working directory.
parent_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(parent_dir, "data")
data_src_dir = os.path.join(data_dir, "ADNIEnhance")
data_dst_dir = os.path.join(data_dir, "ADNIKMSegment")
data_labels = ["AD", "NC"]

method = "km"  # "fcm" or "km"


def main():
    """Segment every subject under ``data_src_dir`` using a process pool.

    For each group in ``data_labels`` and each file in that group's source
    directory, a per-subject output directory is created and the subject is
    queued with its source path, output path and label-volume path.
    """
    create_dir(data_dst_dir)

    data_src_paths, data_dst_paths, labels_paths = [], [], []
    for label in data_labels:
        src_label_dir = os.path.join(data_src_dir, label)
        dst_label_dir = os.path.join(data_dst_dir, label)
        create_dir(dst_label_dir)
        for subject in os.listdir(src_label_dir):
            data_src_paths.append(os.path.join(src_label_dir, subject))
            # Everything before the first dot names the subject directory.
            subj_name = subject.split(".")[0]
            dst_subj_dir = os.path.join(dst_label_dir, subj_name)
            create_dir(dst_subj_dir)
            data_dst_paths.append(os.path.join(dst_subj_dir, subject))
            labels_paths.append(
                os.path.join(dst_subj_dir, subj_name + "_labels.nii.gz"))

    # To try a single subject without the pool:
    #     segment(data_src_paths[0], data_dst_paths[0], labels_paths[0])

    # Multi-processing
    subj_num = len(data_src_paths)
    paras = list(zip(data_src_paths, data_dst_paths,
                     labels_paths, [method] * subj_num))
    pool = Pool(processes=cpu_count())
    try:
        pool.map(unwarp_segment, paras)
    finally:
        # Always release the worker processes, even if a task raised.
        pool.close()
        pool.join()


# Worker processes started with the "spawn" method re-import this module.
# Without this guard each worker would start its own pool on import.
if __name__ == "__main__":
    main()