|
a |
|
b/utils/test_patch.py |
|
|
1 |
import h5py |
|
|
2 |
import math |
|
|
3 |
import nibabel as nib |
|
|
4 |
import numpy as np |
|
|
5 |
from medpy import metric |
|
|
6 |
import torch |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
from tqdm import tqdm |
|
|
9 |
from skimage.measure import label |
|
|
10 |
|
|
|
11 |
def getLargestCC(segmentation):
    """Keep only the largest connected component of a binary segmentation.

    Components are found with ``skimage.measure.label`` (0 is background).

    Args:
        segmentation: Binary segmentation array (0 = background).

    Returns:
        Boolean array of the same shape that is True only on the largest
        connected component.

    Raises:
        AssertionError: If ``segmentation`` contains no foreground component.
    """
    components = label(segmentation)
    # At least one component is required, otherwise the argmax below is meaningless.
    assert components.max() != 0
    # Voxel count per component id; id 0 is the background and is skipped.
    sizes = np.bincount(components.flat)
    biggest_id = 1 + np.argmax(sizes[1:])
    return components == biggest_id
|
|
16 |
|
|
|
17 |
def var_all_case(model, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, dataset_name="LA"):
    """Compute the mean Dice score of ``model`` over a dataset's test split.

    The case ids are read from the dataset's ``test.list`` (one id per line,
    blank lines ignored). Each case's HDF5 file provides 'image' and 'label'
    datasets. Prediction uses ``test_single_case_first_output``.

    Args:
        model: segmentation network passed to the sliding-window helper.
        num_classes: forwarded to the sliding-window helper.
        patch_size, stride_xy, stride_z: sliding-window geometry.
        dataset_name: either "LA" or "Pancreas_CT".

    Returns:
        Mean Dice over all cases. A case with an empty prediction counts as 0.

    Raises:
        ValueError: if ``dataset_name`` is unsupported or the test list is empty.
    """
    if dataset_name == "LA":
        root = '/data/omnisky/postgraduate/Yb/data_set/LASet'
        list_path = root + '/test.list'
        prefix, suffix = root + "/2018LA_Seg_Training Set/", "/mri_norm2.h5"
    elif dataset_name == "Pancreas_CT":
        list_path = './data/Pancreas/test.list'
        prefix, suffix = "./data/Pancreas/Pancreas_h5/", "_norm.h5"
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError("Unsupported dataset_name: {!r}".format(dataset_name))

    with open(list_path, 'r') as f:
        # strip() also removes '\r' from CRLF files; skip blank lines.
        case_ids = [line.strip() for line in f if line.strip()]
    if not case_ids:
        raise ValueError("No test cases listed in {}".format(list_path))
    image_list = [prefix + case_id + suffix for case_id in case_ids]

    total_dice = 0.0
    for image_path in tqdm(image_list):
        # Context manager guarantees the HDF5 handle is closed.
        with h5py.File(image_path, 'r') as h5f:
            image = h5f['image'][:]
            gt = h5f['label'][:]
        prediction, _ = test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        # An empty prediction is scored 0 explicitly instead of being passed to medpy.
        dice = metric.binary.dc(prediction, gt) if np.any(prediction) else 0.0
        total_dice += dice
    avg_dice = total_dice / len(image_list)
    print('average metric is {}'.format(avg_dice))
    return avg_dice
|
|
42 |
|
|
|
43 |
def test_all_case(model_name, num_outputs, model, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None, metric_detail=1, nms=0):
    """Evaluate ``model`` on every HDF5 case in ``image_list``.

    For each case:
      * the volume ('image') and ground truth ('label') are read;
      * the volume is segmented with ``test_single_case_first_output``, and,
        when ``num_outputs > 1``, also with ``test_single_case_average_output``;
      * each prediction is scored with ``calculate_metric_percase``
        (Dice, Jaccard, HD95, ASD).

    A prediction that is empty scores (0, 0, 0, 0). Each prediction is checked
    independently.

    Args:
        model_name: used to name the summary file
            ``test_save_path + '../<model_name>_performance.txt'``.
        num_outputs: number of model outputs; > 1 also evaluates the averaged
            prediction.
        model: segmentation network passed to the sliding-window helpers.
        image_list: sized sequence of paths to HDF5 files with 'image' and
            'label' datasets.
        num_classes: forwarded to the sliding-window helpers.
        patch_size, stride_xy, stride_z: sliding-window geometry.
        save_result: if true, write per-case NIfTI files. These are the
            prediction, score map, image and ground truth.
        test_save_path: string prefix for output files, normally ending with a
            path separator. Required when ``save_result`` is true. The summary
            file is only written when it is given.
        preproc_fn: optional callable applied to each image before inference.
        metric_detail: if truthy, print per-case metrics and skip the tqdm
            progress bar.
        nms: if truthy, keep only the largest connected component of each
            non-empty prediction.

    Returns:
        numpy array of the four metrics of the first output, averaged over
        all cases.

    Raises:
        ValueError: if ``image_list`` is empty, or ``save_result`` is set
            without ``test_save_path``.
    """
    if len(image_list) == 0:
        raise ValueError("image_list is empty")
    if save_result and test_save_path is None:
        # Previously crashed with `None + str` after the first case was evaluated.
        raise ValueError("test_save_path is required when save_result is True")

    use_average = num_outputs > 1
    loader = image_list if metric_detail else tqdm(image_list)
    total_metric = 0.0
    total_metric_average = 0.0
    for ith, image_path in enumerate(loader):
        # Context manager guarantees the HDF5 handle is closed.
        with h5py.File(image_path, 'r') as h5f:
            image = h5f['image'][:]
            gt = h5f['label'][:]
        if preproc_fn is not None:
            image = preproc_fn(image)

        prediction, score_map = test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if use_average:
            prediction_average, score_map_average = test_single_case_average_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if nms:
            prediction = _largest_cc_if_any(prediction)
            if use_average:
                prediction_average = _largest_cc_if_any(prediction_average)

        # Each prediction is scored on its own emptiness (previously both used `prediction`).
        single_metric = _case_metrics(prediction, gt)
        if use_average:
            single_metric_average = _case_metrics(prediction_average, gt)

        if metric_detail:
            print('%02d,\t%.5f, %.5f, %.5f, %.5f' % ((ith,) + tuple(single_metric)))
            if use_average:
                print('%02d,\t%.5f, %.5f, %.5f, %.5f' % ((ith,) + tuple(single_metric_average)))

        total_metric += np.asarray(single_metric)
        if use_average:
            total_metric_average += np.asarray(single_metric_average)

        if save_result:
            _save_nifti(prediction, test_save_path + "%02d_pred.nii.gz" % ith)
            _save_nifti(score_map[0], test_save_path + "%02d_scores.nii.gz" % ith)
            if use_average:
                _save_nifti(prediction_average, test_save_path + "%02d_pred_average.nii.gz" % ith)
                _save_nifti(score_map_average[0], test_save_path + "%02d_scores_average.nii.gz" % ith)
            _save_nifti(image, test_save_path + "%02d_img.nii.gz" % ith)
            _save_nifti(gt, test_save_path + "%02d_gt.nii.gz" % ith)

    avg_metric = total_metric / len(image_list)
    print('average metric is decoder 1 {}'.format(avg_metric))
    if use_average:
        avg_metric_average = total_metric_average / len(image_list)
        print('average metric of all decoders is {}'.format(avg_metric_average))

    if test_save_path is not None:
        with open(test_save_path + '../{}_performance.txt'.format(model_name), 'w') as f:
            f.write('average metric of decoder 1 is {} \n'.format(avg_metric))
            if use_average:
                f.write('average metric of all decoders is {} \n'.format(avg_metric_average))
    return avg_metric


def _case_metrics(pred, gt):
    """Return (dice, jc, hd95, asd) for one case, or zeros when ``pred`` is empty."""
    if not np.any(pred):
        # medpy's distance metrics cannot handle an empty prediction.
        return (0, 0, 0, 0)
    return calculate_metric_percase(pred, gt)


def _largest_cc_if_any(pred):
    """Keep the largest connected component of ``pred``; return an empty ``pred`` unchanged."""
    return getLargestCC(pred) if np.any(pred) else pred


def _save_nifti(array, path):
    """Write ``array`` as a float32 NIfTI file with an identity affine."""
    nib.save(nib.Nifti1Image(np.asarray(array).astype(np.float32), np.eye(4)), path)
|
|
103 |
|
|
|
104 |
|
|
|
105 |
def test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=1): |
|
|
106 |
w, h, d = image.shape |
|
|
107 |
|
|
|
108 |
# if the size of image is less than patch_size, then padding it |
|
|
109 |
add_pad = False |
|
|
110 |
if w < patch_size[0]: |
|
|
111 |
w_pad = patch_size[0]-w |
|
|
112 |
add_pad = True |
|
|
113 |
else: |
|
|
114 |
w_pad = 0 |
|
|
115 |
if h < patch_size[1]: |
|
|
116 |
h_pad = patch_size[1]-h |
|
|
117 |
add_pad = True |
|
|
118 |
else: |
|
|
119 |
h_pad = 0 |
|
|
120 |
if d < patch_size[2]: |
|
|
121 |
d_pad = patch_size[2]-d |
|
|
122 |
add_pad = True |
|
|
123 |
else: |
|
|
124 |
d_pad = 0 |
|
|
125 |
wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 |
|
|
126 |
hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 |
|
|
127 |
dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 |
|
|
128 |
if add_pad: |
|
|
129 |
image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) |
|
|
130 |
ww,hh,dd = image.shape |
|
|
131 |
|
|
|
132 |
sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 |
|
|
133 |
sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 |
|
|
134 |
sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 |
|
|
135 |
# print("{}, {}, {}".format(sx, sy, sz)) |
|
|
136 |
score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) |
|
|
137 |
cnt = np.zeros(image.shape).astype(np.float32) |
|
|
138 |
|
|
|
139 |
for x in range(0, sx): |
|
|
140 |
xs = min(stride_xy*x, ww-patch_size[0]) |
|
|
141 |
for y in range(0, sy): |
|
|
142 |
ys = min(stride_xy * y,hh-patch_size[1]) |
|
|
143 |
for z in range(0, sz): |
|
|
144 |
zs = min(stride_z * z, dd-patch_size[2]) |
|
|
145 |
test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] |
|
|
146 |
test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) |
|
|
147 |
test_patch = torch.from_numpy(test_patch).cuda() |
|
|
148 |
|
|
|
149 |
with torch.no_grad(): |
|
|
150 |
y = model(test_patch) |
|
|
151 |
if len(y) > 1: |
|
|
152 |
y = y[0] |
|
|
153 |
y = F.softmax(y, dim=1) |
|
|
154 |
y = y.cpu().data.numpy() |
|
|
155 |
y = y[0,1,:,:,:] |
|
|
156 |
score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ |
|
|
157 |
= score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y |
|
|
158 |
cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ |
|
|
159 |
= cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 |
|
|
160 |
|
|
|
161 |
score_map = score_map/np.expand_dims(cnt,axis=0) |
|
|
162 |
label_map = (score_map[0]>0.5).astype(np.int) |
|
|
163 |
if add_pad: |
|
|
164 |
label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] |
|
|
165 |
score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] |
|
|
166 |
return label_map, score_map |
|
|
167 |
|
|
|
168 |
def test_single_case_average_output(net, image, stride_xy, stride_z, patch_size, num_classes=1): |
|
|
169 |
w, h, d = image.shape |
|
|
170 |
|
|
|
171 |
# if the size of image is less than patch_size, then padding it |
|
|
172 |
add_pad = False |
|
|
173 |
if w < patch_size[0]: |
|
|
174 |
w_pad = patch_size[0]-w |
|
|
175 |
add_pad = True |
|
|
176 |
else: |
|
|
177 |
w_pad = 0 |
|
|
178 |
if h < patch_size[1]: |
|
|
179 |
h_pad = patch_size[1]-h |
|
|
180 |
add_pad = True |
|
|
181 |
else: |
|
|
182 |
h_pad = 0 |
|
|
183 |
if d < patch_size[2]: |
|
|
184 |
d_pad = patch_size[2]-d |
|
|
185 |
add_pad = True |
|
|
186 |
else: |
|
|
187 |
d_pad = 0 |
|
|
188 |
wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 |
|
|
189 |
hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 |
|
|
190 |
dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 |
|
|
191 |
if add_pad: |
|
|
192 |
image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) |
|
|
193 |
ww,hh,dd = image.shape |
|
|
194 |
|
|
|
195 |
sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 |
|
|
196 |
sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 |
|
|
197 |
sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 |
|
|
198 |
# print("{}, {}, {}".format(sx, sy, sz)) |
|
|
199 |
score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) |
|
|
200 |
cnt = np.zeros(image.shape).astype(np.float32) |
|
|
201 |
|
|
|
202 |
for x in range(0, sx): |
|
|
203 |
xs = min(stride_xy*x, ww-patch_size[0]) |
|
|
204 |
for y in range(0, sy): |
|
|
205 |
ys = min(stride_xy * y,hh-patch_size[1]) |
|
|
206 |
for z in range(0, sz): |
|
|
207 |
zs = min(stride_z * z, dd-patch_size[2]) |
|
|
208 |
test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] |
|
|
209 |
test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) |
|
|
210 |
test_patch = torch.from_numpy(test_patch).cuda() |
|
|
211 |
|
|
|
212 |
with torch.no_grad(): |
|
|
213 |
y_logit = net(test_patch) |
|
|
214 |
num_outputs = len(y_logit) |
|
|
215 |
y=torch.zeros(y_logit[0].shape).cuda() |
|
|
216 |
for idx in range(num_outputs): |
|
|
217 |
y += y_logit[idx] |
|
|
218 |
y/=num_outputs |
|
|
219 |
|
|
|
220 |
y = y.cpu().data.numpy() |
|
|
221 |
y = y[0,1,:,:,:] |
|
|
222 |
score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ |
|
|
223 |
= score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y |
|
|
224 |
cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ |
|
|
225 |
= cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 |
|
|
226 |
|
|
|
227 |
score_map = score_map/np.expand_dims(cnt,axis=0) |
|
|
228 |
label_map = (score_map[0]>0.5).astype(np.int) |
|
|
229 |
if add_pad: |
|
|
230 |
label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] |
|
|
231 |
score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] |
|
|
232 |
return label_map, score_map |
|
|
233 |
|
|
|
234 |
def calculate_metric_percase(pred, gt):
    """Score one predicted segmentation against its ground truth with medpy.

    Args:
        pred: predicted binary segmentation.
        gt: ground-truth binary segmentation.

    Returns:
        Tuple ``(dice, jaccard, hd95, asd)`` from ``medpy.metric.binary``.
    """
    binary = metric.binary
    return (
        binary.dc(pred, gt),
        binary.jc(pred, gt),
        binary.hd95(pred, gt),
        binary.asd(pred, gt),
    )