a b/src/fast_segment.py
1
from __future__ import print_function
2
3
import os
4
import shutil
5
import subprocess
6
from multiprocessing import Pool, cpu_count
7
8
9
def create_dir(path):
    """Ensure *path* exists as a directory, creating missing parents.

    Does nothing when the directory is already present.
    """
    already_there = os.path.isdir(path)
    if already_there:
        return
    os.makedirs(path)
13
14
15
def fast(src_path, dst_path, temp_dir, temp_path):
    """Run the external ``fast`` segmenter on one image and keep one output.

    Args:
        src_path: input image, passed as the last argument to ``fast``.
        dst_path: destination the selected output file is copied to.
        temp_dir: output basename handed to ``fast`` via ``-o``; its parent
            directory is treated as scratch space and deleted afterwards.
        temp_path: the output file (inside the scratch directory) to keep.

    Raises:
        RuntimeError: if ``fast`` cannot be launched, exits with a non-zero
            status, does not produce ``temp_path``, or the copy fails.
    """
    # Presumably FSL FAST: T1 input (-t 1), 3 tissue classes (-n 3) plus
    # bias-field settings. NOTE(review): confirm flags against the FAST docs.
    command = ["fast", "-t", "1", "-n", "3", "-H", "0.1", "-I", "1",
               "-l", "20.0", "-o", temp_dir, src_path]
    scratch_dir = os.path.dirname(temp_dir)
    try:
        # devnull must be opened for *writing*: the child writes its output
        # there, and a read-only handle would make those writes fail. The
        # ``with`` also closes the handle once the child has exited.
        with open(os.devnull, "w") as devnull:
            try:
                returncode = subprocess.call(command, stdout=devnull,
                                             stderr=subprocess.STDOUT)
            except OSError as err:
                # e.g. the executable is not on PATH.
                raise RuntimeError("could not run fast: %s" % err)
        if returncode != 0:
            raise RuntimeError("fast exited with status %d on %s"
                               % (returncode, src_path))
        if not os.path.isfile(temp_path):
            raise RuntimeError("fast did not produce %s" % temp_path)
        try:
            shutil.copyfile(temp_path, dst_path)
        except (IOError, OSError) as err:
            raise RuntimeError("could not copy %s to %s: %s"
                               % (temp_path, dst_path, err))
    finally:
        # Remove the scratch directory whether or not segmentation worked,
        # so failed subjects do not leave intermediate files behind.
        shutil.rmtree(scratch_dir, ignore_errors=True)
22
23
24
def unwarp_segment(arg, **kwarg):
    """Adapter for ``Pool.map``: spread one argument tuple into segment().

    ``Pool.map`` hands each worker a single item, so the per-subject tuple
    ``(src_path, dst_path, temp_dir, temp_path)`` is unpacked here and any
    keyword arguments are forwarded unchanged.
    """
    positional = tuple(arg)
    result = segment(*positional, **kwarg)
    return result
26
27
28
def segment(src_path, dst_path, temp_dir, temp_path):
    """Segment one subject, reporting (rather than raising) failures.

    This runs inside pool workers; an exception escaping here would be
    re-raised by ``Pool.map`` in the parent and abort the whole batch, so
    failures are printed and processing continues with the next subject.

    Args:
        src_path: input image for this subject.
        dst_path: where the kept segmentation output is written.
        temp_dir: output basename for the segmenter (scratch location).
        temp_path: segmenter output file to copy to ``dst_path``.
    """
    print("Segment on: ", src_path)
    try:
        fast(src_path, dst_path, temp_dir, temp_path)
    except (RuntimeError, IOError, OSError) as err:
        # OSError/IOError cover launch and copy failures as well as the
        # RuntimeError used to signal a failed segmentation.
        print("\tFailed on: ", src_path, "-", err)
    return
35
36
37
# Input/output locations, relative to the parent of the working directory.
parent_dir = os.path.dirname(os.getcwd())
data_dir = os.path.join(parent_dir, "data")
data_src_dir = os.path.join(data_dir, "ADNIDenoise")
data_dst_dir = os.path.join(data_dir, "ADNISegment")
data_labels = ["AD", "NC"]


def _collect_jobs(src_root, dst_root, labels):
    """Build one (src_path, dst_path, temp_dir, temp_path) tuple per subject.

    Also creates each label's output directory and a per-subject scratch
    directory beneath it.
    """
    jobs = []
    for label in labels:
        src_label_dir = os.path.join(src_root, label)
        dst_label_dir = os.path.join(dst_root, label)
        create_dir(dst_label_dir)
        # Sorted so subjects are processed and logged in a stable order.
        for subject in sorted(os.listdir(src_label_dir)):
            subj_name = subject.split(".")[0]
            temp_dir = os.path.join(dst_label_dir, subj_name, subj_name)
            create_dir(os.path.dirname(temp_dir))
            jobs.append((os.path.join(src_label_dir, subject),
                         os.path.join(dst_label_dir, subject),
                         temp_dir,
                         temp_dir + "_pve_1.nii.gz"))
    return jobs


def main():
    """Segment every subject of every label in parallel, one per core."""
    create_dir(data_dst_dir)
    jobs = _collect_jobs(data_src_dir, data_dst_dir, data_labels)

    pool = Pool(processes=cpu_count())
    try:
        pool.map(unwarp_segment, jobs)
    finally:
        # Release worker processes instead of leaving them to interpreter exit.
        pool.close()
        pool.join()


# Required by multiprocessing: with the "spawn" start method each worker
# re-imports this module, and unguarded top-level work would re-scan the
# data directories and try to start a nested Pool in every worker.
if __name__ == "__main__":
    main()