tools/test_acdc_leadboard.py

import os
import sys
import numpy as np
import nibabel as nib
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import _init_paths
from libs.network import U_Net, U_NetDF
from utils.image_list import to_image_list
import libs.datasets.augment as standard_augment
import libs.datasets.joint_augment as joint_augment

def progress_bar(curr_idx, max_idx, time_step, repeat_elem="_"):
    max_equals = 55
    step_ms = int(time_step * 1000)
    num_equals = int(curr_idx * max_equals / float(max_idx))
    len_reverse = len('Step:%d ms| %d/%d [' % (step_ms, curr_idx, max_idx)) + num_equals
    sys.stdout.write("Step:%d ms| %d/%d [%s]" % (step_ms, curr_idx, max_idx, " " * max_equals))
    sys.stdout.flush()
    sys.stdout.write("\b" * (max_equals + 1))
    sys.stdout.write(repeat_elem * num_equals)
    sys.stdout.write("\b" * len_reverse)
    sys.stdout.flush()
    if curr_idx == max_idx:
        print('\n')

def load_nii(img_path):
    """
    Load a 'nii' or 'nii.gz' file. The function returns everything needed
    to save another 'nii' or 'nii.gz' in the same dimensional space,
    i.e. the affine matrix and the header.

    Parameters
    ----------
    img_path: string
        Path of the 'nii' or 'nii.gz' image file.

    Returns
    -------
    Three elements: the first is a numpy array of the image values,
    the second is the affine transformation of the image, and the
    last one is the header of the image.
    """
    nimg = nib.load(img_path)
    return nimg.get_fdata(), nimg.affine, nimg.header

def save_nii(vol, affine, hdr, path, prefix, suffix):
    vol = nib.Nifti1Image(vol, affine, hdr)
    vol.set_data_dtype(np.uint8)
    nib.save(vol, os.path.join(path, prefix + '_' + suffix + ".nii.gz"))


def get_person_names(root_path=None):
    persons_name = os.listdir(root_path)
    persons_name = [pn for pn in persons_name if "patient" in pn]
    persons_name.sort()
    return persons_name

def get_patient_data(patient, root_path):
    patient_data = {}
    infocfg_p = os.path.join(root_path, patient, "Info.cfg")

    with open(infocfg_p) as f_in:
        for line in f_in:
            l = line.rstrip().split(": ")
            patient_data[l[0]] = l[1]

    ed_path = os.path.join(root_path, patient, "%s_frame%02d.nii.gz" % (patient, int(patient_data['ED'])))
    es_path = os.path.join(root_path, patient, "%s_frame%02d.nii.gz" % (patient, int(patient_data['ES'])))
    img_4d_path = os.path.join(root_path, patient, "{}_4d.nii.gz".format(patient))
    # ed_gt_path = os.path.join(root_path, patient, "%s_frame%02d_gt.nii.gz" % (patient, int(patient_data['ED'])))
    # es_gt_path = os.path.join(root_path, patient, "%s_frame%02d_gt.nii.gz" % (patient, int(patient_data['ES'])))

    ed, affine, hdr = load_nii(ed_path)
    patient_data['ED_VOL'] = np.swapaxes(ed, 0, -1)
    patient_data['3D_affine'] = affine
    patient_data['3D_hdr'] = hdr

    es, _, _ = load_nii(es_path)  # (w, h, slices)
    patient_data['ES_VOL'] = np.swapaxes(es, 0, -1)

    img_4d, affine_4d, hdr_4d = load_nii(img_4d_path)  # (w, h, slices, times)
    patient_data['4D'] = np.swapaxes(img_4d, 0, 1)
    patient_data['4D_affine'] = affine_4d
    patient_data['4D_hdr'] = hdr_4d

    patient_data['size'] = img_4d.shape[:2][::-1]
    patient_data['pid'] = patient

    # ed_gt = load_nii(ed_gt_path)
    # patient_data['ED_GT'] = np.swapaxes(ed_gt, 0, 1)

    # es_gt = load_nii(es_gt_path)
    # patient_data['ES_GT'] = np.swapaxes(es_gt, 0, 1)
    return patient_data

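# For reference (derived from the code above): get_patient_data() returns a dict
# holding the Info.cfg fields (e.g. 'ED', 'ES'), the ED/ES volumes as (slices, H, W)
# arrays under 'ED_VOL'/'ES_VOL', the 3-D affine/header under '3D_affine'/'3D_hdr',
# the full 4-D image under '4D' with '4D_affine'/'4D_hdr', the in-plane output
# 'size', and the patient id 'pid'.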
|
|
|
def test_it(model, data, device="cuda", used_df=True):
    model.eval()
    imgs = data

    imgs = imgs.to(device)
    # gts = gts.to(device)

    net_out = model(imgs)
    if used_df:
        preds_out = net_out[0]
        preds_df = net_out[1]
    else:
        preds_out = net_out[0]
        preds_df = None
    preds_out = nn.functional.softmax(preds_out, dim=1)
    _, preds = torch.max(preds_out, 1)
    preds = preds.unsqueeze(1)  # (N, 1, *)

    return preds, preds_df

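# Usage sketch for test_it() (hypothetical shapes; assumes the network returns the
# segmentation logits first and, when used_df is True, the direction-field output second):
#   dummy = torch.randn(1, 1, 224, 224)                     # (N, C, H, W)
#   preds, preds_df = test_it(model, dummy, used_df=True)   # preds: (1, 1, 224, 224) class labels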
|
|
|
def transform(imgs):
    mean = 63.19523533061758
    std = 70.74166957523165
    trans = standard_augment.Compose([standard_augment.To_PIL_Image(),
                                      # joint_augment.RandomAffine(0, translate=(0.125, 0.125)),
                                      # joint_augment.RandomRotate((-180, 180)),
                                      # joint_augment.FixResize(224),
                                      standard_augment.to_Tensor(),
                                      standard_augment.normalize([mean], [std]),
                                      ])
    return trans(imgs)

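# Usage sketch (mirrors the per-slice call in test_voxel() below): transform()
# expects a single-channel (H, W, 1) float32 array and returns a normalized tensor.
#   slice_tensor = transform(img_slice[..., None].astype(np.float32))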
|
|
|
def test_voxel(model, imgs, used_df, multi_batches=False, resize=None):
    """ imgs: (slices, H, W)
        preds: (slices, 1, H, W)
    """
    imgs = imgs[..., None].astype(np.float32)
    B, _, _, C = imgs.shape

    if multi_batches:
        data, origin_shape = to_image_list(imgs, size_divisible=32, return_size=True)
        preds, _ = test_it(model, data)

        # for j in range(imgs.shape[0]):
        #     preds[j, ...] = pred.cpu().numpy()[j, :, :origin_shape[j][0], :origin_shape[j][1]]
    else:
        preds = torch.zeros(B, C, resize[0], resize[1])
        for j, pt in enumerate(imgs):
            data = [transform(pt)]
            data, origin_shape = to_image_list(data, size_divisible=32, return_size=True)
            pred, _ = test_it(model, data, used_df=used_df)
            preds[j, ...] = pred[0, 0, :origin_shape[0][0], :origin_shape[0][1]]

    if resize is not None:
        # preds = F.interpolate(preds, size=resize, mode='nearest')
        preds = preds[..., :resize[0], :resize[1]]

    return preds.cpu().numpy()[:, 0, ...]

def create_model(model_name, selfeat):
    if model_name == 'U_NetDF':
        model = U_NetDF(selfeat=selfeat, auxseg=True)
    elif model_name == 'U_Net':
        model = U_Net()
    # elif model_name == 'Resnet18_DfUnet':
    #     model = Resnet18_DfUnet()
    # elif model_name == 'DenseNet':
    #     model = DenseNet()
    # elif model_name == 'DenseNet_DF':
    #     model = DenseNet_DF(selfeat=selfeat)

    return model

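# Usage sketch (mirrors the call in test() below):
#   model = create_model('U_NetDF', selfeat=True)  # U_NetDF is built with auxseg=True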
|
|
|
def test(mgpus, model_name, model_path, used_df, selfeat, log_path):

    model = create_model(model_name, selfeat=selfeat)
    if mgpus is not None and len(mgpus) > 2:
        model = nn.DataParallel(model)
    model.cuda()

    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state'])

    # root_path = "MICCAIACDC2017/ACDC_DataSet/testing/testing/"
    root_path = "/root/ACDC_DataSet/testing/testing/"
    persons_name = get_person_names(root_path)
    for j, pn in enumerate(persons_name):
        s_time = time.time()
        patient_data = get_patient_data(pn, root_path)

        # (slices, h, w)
        es_pred = test_voxel(model, patient_data['ES_VOL'], used_df=used_df, resize=patient_data['size'])
        ed_pred = test_voxel(model, patient_data['ED_VOL'], used_df=used_df, resize=patient_data['size'])
        es_pred = np.transpose(es_pred, (2, 1, 0))
        ed_pred = np.transpose(ed_pred, (2, 1, 0))

        img_4D = patient_data['4D']
        h, w, s, t = img_4D.shape
        pred_4D = np.zeros((w, h, s, t))
        for i in range(img_4D.shape[-1]):
            pred = test_voxel(model, np.transpose(img_4D[..., i], (2, 0, 1)), used_df=used_df, resize=patient_data['size'])
            pred_4D[..., i] = np.transpose(pred, (2, 1, 0))

        save_path = os.path.join(log_path, "all_predictions")
        os.makedirs(save_path, exist_ok=True)
        CheckSizeAndSaveVolume(pred_4D, patient_data, save_path)
        progress_bar(j + 1, len(persons_name), time.time() - s_time)


def CheckSizeAndSaveVolume(seg_4D, patient_data, save_path):
    """
    TODO:
    """
    prefix = patient_data['pid']
    suffix = '4D'

    # save_nii(seg_4D, patient_data['4D_affine'], patient_data['4D_hdr'], save_path, prefix, suffix)
    suffix = 'ED'
    ED_phase_n = int(patient_data['ED'])
    ED_pred = seg_4D[:, :, :, ED_phase_n]
    save_nii(ED_pred.astype(np.uint8), patient_data['3D_affine'], patient_data['3D_hdr'], save_path, prefix, suffix)

    suffix = 'ES'
    ES_phase_n = int(patient_data['ES'])
    ES_pred = seg_4D[:, :, :, ES_phase_n]
    save_nii(ES_pred.astype(np.uint8), patient_data['3D_affine'], patient_data['3D_hdr'], save_path, prefix, suffix)

    # ED_GT = patient_data.get('ED_GT', None)
    results = []
    return results

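# Note: given save_nii() above, the ED and ES predictions are written as
# "<patient>_ED.nii.gz" and "<patient>_ES.nii.gz" under <output_dir>/all_predictions,
# presumably the layout expected for an ACDC leaderboard submission.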
|
|
|
if __name__ == "__main__":
    # get_person_names()
    import argparse
    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument('--mgpus', type=str, default='0', required=False, help='GPU id(s) to use, e.g. "0" or "0,1"')
    parser.add_argument('--used_df', type=str, default='True', help='whether to use the direction field (df) output')
    parser.add_argument('--model', type=str, default='', help='model name (U_Net or U_NetDF)')
    parser.add_argument('--selfeat', type=bool, default=True, help='whether to use feature selection')
    parser.add_argument('--model_path', type=str, default=None, help='path to the model checkpoint')
    parser.add_argument('--output_dir', type=str, default="logs/acdc_logs/logs_supcat_auxseg/predtions", required=False, help='specify an output directory if needed')
    parser.add_argument('--log_file', type=str, default="../log_predtion.txt", help="the file to write logging")

    args = parser.parse_args()
    if args.mgpus is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus

    # NOTE: the checkpoint path is hard-coded here; the --model_path argument is not used.
    model_path = "logs/acdc_logs/logs_supcat_auxseg/ckpt/checkpoint_epoch_70.pth"

    test(args.mgpus, "U_NetDF", model_path, args.used_df, args.selfeat, args.output_dir)
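
# Example invocation (a sketch; assumes the hard-coded checkpoint and the ACDC test
# set under /root/ACDC_DataSet/testing/testing/ are in place):
#   python tools/test_acdc_leadboard.py --mgpus 0 --output_dir logs/acdc_logs/logs_supcat_auxseg/predtions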