|
a |
|
b/src/fast_segment.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import shutil |
|
|
5 |
import subprocess |
|
|
6 |
from multiprocessing import Pool, cpu_count |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def create_dir(path):
    """Ensure that *path* exists as a directory.

    Missing parent directories are created too. Nothing happens if *path*
    is already a directory.

    Args:
        path: Directory path to create if absent.
    """
    already_there = os.path.isdir(path)
    if already_there:
        return
    os.makedirs(path)
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def fast(src_path, dst_path, temp_dir, temp_path):
    """Segment one image with the external ``fast`` tool and keep one output.

    Runs ``fast`` with a fixed option set (presumably FSL FAST; the flags
    look like its T1 / 3-class options, so confirm), using *temp_dir* as the
    ``-o`` output basename. It then copies *temp_path* to *dst_path*. The
    parent directory of *temp_dir* is treated as scratch space and is
    removed afterwards, whether or not the run succeeded.

    Args:
        src_path: Input image passed to ``fast``.
        dst_path: Destination for the copied output file.
        temp_dir: Output basename handed to ``fast -o``.
        temp_path: File produced by ``fast`` that is copied to *dst_path*.

    Raises:
        RuntimeError: If ``fast`` cannot be started, exits with a non-zero
            status, or does not produce *temp_path*.
    """
    command = ["fast", "-t", "1", "-n", "3", "-H", "0.1", "-I", "1",
               "-l", "20.0", "-o", temp_dir, src_path]
    scratch_dir = os.path.dirname(temp_dir)
    try:
        # The child's stdout/stderr are redirected here. The handle must be
        # opened for writing, and it is closed by the ``with`` block.
        with open(os.devnull, "w") as devnull:
            try:
                return_code = subprocess.call(command, stdout=devnull,
                                              stderr=subprocess.STDOUT)
            except OSError as err:
                # For example, the ``fast`` executable is not on PATH.
                raise RuntimeError("could not run fast on %s: %s"
                                   % (src_path, err))
        if return_code != 0:
            raise RuntimeError("fast exited with status %d on %s"
                               % (return_code, src_path))
        if not os.path.isfile(temp_path):
            raise RuntimeError("fast did not produce %s" % temp_path)
        shutil.copyfile(temp_path, dst_path)
    finally:
        # Always discard scratch output so failed runs leave no debris.
        # Skip an empty dirname so we never target the current directory.
        if scratch_dir:
            shutil.rmtree(scratch_dir, ignore_errors=True)
    return
|
|
22 |
|
|
|
23 |
|
|
|
24 |
def unwarp_segment(arg, **kwarg):
    """Expand one packed argument sequence and forward it to ``segment``.

    ``Pool.map`` passes each work item as a single argument. This adapter
    spreads that item into ``segment``'s positional parameters and passes
    any keyword arguments through unchanged.

    Args:
        arg: Iterable of the positional arguments for ``segment``.
        **kwarg: Extra keyword arguments for ``segment``.

    Returns:
        Whatever ``segment`` returns.
    """
    positional = tuple(arg)
    return segment(*positional, **kwarg)
|
|
26 |
|
|
|
27 |
|
|
|
28 |
def segment(src_path, dst_path, temp_dir, temp_path):
    """Segment one subject via ``fast``, logging failures instead of raising.

    This function is used as the per-item worker of the process pool.
    Swallowing per-subject failures keeps one bad input from aborting the
    whole batch.

    Args:
        src_path: Input image.
        dst_path: Where the selected segmentation output is written.
        temp_dir: Output basename for ``fast``'s scratch files.
        temp_path: Scratch output file that ``fast`` copies to *dst_path*.
    """
    print("Segment on: ", src_path)
    try:
        fast(src_path, dst_path, temp_dir, temp_path)
    except (RuntimeError, EnvironmentError) as err:
        # ``fast`` can also fail with OS-level errors: a missing executable,
        # a missing output file, or an unwritable destination. If only
        # RuntimeError were caught, those would propagate out of the worker
        # and make Pool.map raise for the entire run.
        print("\tFailed on: ", src_path, "-", err)
    return
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def main():
    """Segment every subject image under ``../data/ADNIDenoise/{AD,NC}``.

    Paths are resolved against the parent of the current working
    directory, so the script appears to expect being run from inside its
    own source folder (TODO confirm). Each input file is segmented in
    parallel. The result is written to ``../data/ADNISegment/<label>/``
    under the same file name.
    """
    parent_dir = os.path.dirname(os.getcwd())
    data_dir = os.path.join(parent_dir, "data")
    data_src_dir = os.path.join(data_dir, "ADNIDenoise")
    data_dst_dir = os.path.join(data_dir, "ADNISegment")
    data_labels = ["AD", "NC"]
    create_dir(data_dst_dir)

    # One (src_path, dst_path, temp_dir, temp_path) tuple per subject.
    tasks = []
    for label in data_labels:
        src_label_dir = os.path.join(data_src_dir, label)
        dst_label_dir = os.path.join(data_dst_dir, label)
        create_dir(dst_label_dir)
        # Sort for a deterministic processing order.
        for subject in sorted(os.listdir(src_label_dir)):
            subj_name = subject.split(".")[0]
            # Per-subject scratch folder <label>/<subj>/ that receives fast's
            # <subj>_* outputs. fast() deletes this folder afterwards.
            temp_dir = os.path.join(dst_label_dir, subj_name, subj_name)
            create_dir(os.path.dirname(temp_dir))
            tasks.append((os.path.join(src_label_dir, subject),
                          os.path.join(dst_label_dir, subject),
                          temp_dir,
                          temp_dir + "_pve_1.nii.gz"))

    pool = Pool(processes=cpu_count())
    try:
        pool.map(unwarp_segment, tasks)
    finally:
        # Stop accepting work and wait for the workers to exit cleanly.
        pool.close()
        pool.join()


# multiprocessing requires this guard. Under the "spawn" start method
# (Windows, and macOS on Python 3.8+) every worker re-imports this module,
# and unguarded top-level code would try to start a new pool in each one.
if __name__ == "__main__":
    main()