|
a |
|
b/utils/evalF.py |
|
|
1 |
import sys |
|
|
2 |
import os |
|
|
3 |
|
|
|
4 |
import evalMetrics as METRICS |
|
|
5 |
import PP |
|
|
6 |
import numpy as np |
|
|
7 |
import torch |
|
|
8 |
import torch.nn as nn |
|
|
9 |
from torch.autograd import Variable |
|
|
10 |
|
|
|
11 |
import augmentations as AUG |
|
|
12 |
|
|
|
13 |
#--------------------------------------------- |
|
|
14 |
#Evaluation functions |
|
|
15 |
#--------------------------------------------- |
|
|
16 |
|
|
|
17 |
def evalModelX(model, num_labels, postfix, main_folder_path, eval_method, gpu0, useGPU, |
|
|
18 |
patch_size = 70, eval_metric = 'iou', test_augm = False, extra_patch = 30): |
|
|
19 |
eval_list = main_folder_path + 'val' + postfix + '.txt' |
|
|
20 |
img_list = open(eval_list).readlines() |
|
|
21 |
v = 0 |
|
|
22 |
v_priv = 0 |
|
|
23 |
for img_str in img_list: |
|
|
24 |
img_str = img_str.rstrip() |
|
|
25 |
_, gt, out, _ = predict(os.path.join(main_folder_path, img_str), model, num_labels, postfix, |
|
|
26 |
main_folder_path, eval_method, gpu0, useGPU, patch_size=patch_size, |
|
|
27 |
test_augm = test_augm, extra_patch = extra_patch) |
|
|
28 |
curr_eval = METRICS.metricEval(eval_metric, out, gt, num_labels) |
|
|
29 |
v+=curr_eval |
|
|
30 |
return v / len(img_list) |
|
|
31 |
|
|
|
32 |
def testPredict(img, model, num_labels, eval_method, gpu0, useGPU, stride= 50, patch_size = 70, test_augm = True, extra_patch = 30, get_soft = False): |
|
|
33 |
if eval_method == 0: |
|
|
34 |
if useGPU: |
|
|
35 |
out = model(Variable(torch.from_numpy(img).float(),volatile = True).cuda(gpu0)) |
|
|
36 |
else: |
|
|
37 |
out = model(Variable(torch.from_numpy(img).float(),volatile = True)) |
|
|
38 |
out = out.data[0].cpu().numpy() |
|
|
39 |
elif eval_method == 1: |
|
|
40 |
out = predictByPatches(img, model, num_labels, useGPU, gpu0, |
|
|
41 |
stride = stride, patch_size = patch_size, |
|
|
42 |
test_augm = test_augm, extra_patch = extra_patch) |
|
|
43 |
out = out.squeeze() |
|
|
44 |
if get_soft: |
|
|
45 |
return out |
|
|
46 |
#take argmax to get predictions |
|
|
47 |
out = np.argmax(out, axis = 0) |
|
|
48 |
#remove batch and label dimension |
|
|
49 |
out = out.squeeze() |
|
|
50 |
return out |
|
|
51 |
|
|
|
52 |
#returns the image as numpy, the ground truth and the prediction given model and input path |
|
|
53 |
#affine = True, returns the affine transformation from loading the scan |
|
|
54 |
def predict(img_path, model, num_labels, postfix, main_folder_path, eval_method, gpu0, useGPU, |
|
|
55 |
stride = 50, patch_size = 70, test_augm = True, extra_patch = 30): |
|
|
56 |
#read image |
|
|
57 |
img = PP.numpyFromScan(img_path) |
|
|
58 |
#read wmh |
|
|
59 |
gt_path = img_path.replace('slices', 'gt_slices').replace('FLAIR', 'wmh').replace('/pre','') |
|
|
60 |
gt, affine = PP.numpyFromScan(gt_path, get_affine = True, makebin = (num_labels == 2)) |
|
|
61 |
|
|
|
62 |
img = img.transpose((3,0,1,2)) |
|
|
63 |
img = img[np.newaxis, :] |
|
|
64 |
gt = gt.transpose((3,0,1,2)) |
|
|
65 |
|
|
|
66 |
if eval_method == 0: |
|
|
67 |
if useGPU: |
|
|
68 |
out_v = model(Variable(torch.from_numpy(img).float(),volatile = True).cuda(gpu0)) |
|
|
69 |
else: |
|
|
70 |
out_v = model(Variable(torch.from_numpy(img).float(),volatile = True)) |
|
|
71 |
out = out_v.data[0].cpu().numpy() |
|
|
72 |
#FIX? |
|
|
73 |
del out_v |
|
|
74 |
out_v = Variable(torch.from_numpy(np.array([1])).float()) |
|
|
75 |
out_v = Variable(torch.from_numpy(np.array([1])).float()) |
|
|
76 |
elif eval_method == 1: |
|
|
77 |
out = predictByPatches(img, model, num_labels, useGPU, gpu0, stride = stride, patch_size = patch_size, test_augm = test_augm, extra_patch = extra_patch) |
|
|
78 |
out = out.squeeze() |
|
|
79 |
#take argmax to get predictions |
|
|
80 |
out = np.argmax(out, axis = 0) |
|
|
81 |
#remove batch and label dimension |
|
|
82 |
img = img.squeeze() |
|
|
83 |
out = out.squeeze() |
|
|
84 |
gt = gt.squeeze() |
|
|
85 |
|
|
|
86 |
return img, gt, out, affine |
|
|
87 |
|
|
|
88 |
def predictByPatches(img, model, num_labels, useGPU, gpu0, patch_size = 70, test_augm = False, stride = 50, extra_pad = 0, extra_patch = 30): |
|
|
89 |
batch_num, num_channels, dim1, dim2, dim3 = img.shape |
|
|
90 |
p_size = patch_size |
|
|
91 |
#add padding to each dim s.t. % stride = 0 |
|
|
92 |
dim1_pad = (stride - ((dim1-p_size) % stride)) % stride |
|
|
93 |
dim2_pad = (stride - ((dim2-p_size) % stride)) % stride |
|
|
94 |
dim3_pad = (stride - ((dim3-p_size) % stride)) % stride |
|
|
95 |
|
|
|
96 |
x_1_off, x_2_off = int(round(dim1_pad/2.0)), dim1_pad//2 |
|
|
97 |
y_1_off, y_2_off = int(round(dim2_pad/2.0)), dim2_pad//2 |
|
|
98 |
z_1_off, z_2_off = int(round(dim3_pad/2.0)), dim3_pad//2 |
|
|
99 |
|
|
|
100 |
img = np.lib.pad(img, ((0,0),(0,0), (x_1_off, x_2_off), (y_1_off, y_2_off), (z_1_off, z_2_off)), mode='minimum') |
|
|
101 |
_, _, padded_dim1, padded_dim2, padded_dim3 = img.shape |
|
|
102 |
|
|
|
103 |
out_shape = (img.shape[0], num_labels, img.shape[2], img.shape[3], img.shape[4]) |
|
|
104 |
out_total = np.zeros(out_shape) |
|
|
105 |
out_counter = np.zeros(out_shape) |
|
|
106 |
|
|
|
107 |
extra_p = extra_patch / 2 |
|
|
108 |
for i in range(0, padded_dim1 - p_size + 1, stride): |
|
|
109 |
for j in range(0, padded_dim2 - p_size + 1, stride): |
|
|
110 |
for k in range(0, padded_dim3 - p_size + 1, stride): |
|
|
111 |
if extra_p != 0: |
|
|
112 |
i_l, i_r = getExtraPatchOffsets(i, 0, padded_dim1 - p_size, extra_p) |
|
|
113 |
j_l, j_r = getExtraPatchOffsets(j, 0, padded_dim2 - p_size, extra_p) |
|
|
114 |
k_l, k_r = getExtraPatchOffsets(k, 0, padded_dim3 - p_size, extra_p) |
|
|
115 |
|
|
|
116 |
img_patch = img[:,:, (i-i_l):(i+p_size+i_r),(j-j_l):(j+p_size+j_r),(k-k_l):(k+p_size+k_r)] |
|
|
117 |
out_np = getPatchPrediction(img_patch, model, useGPU, gpu0, extra_pad = extra_pad, test_augm = test_augm) |
|
|
118 |
out_np = removePatchOffset(out_np, i_l, i_r, j_l, j_r, k_l, k_r) |
|
|
119 |
out_total[:,:, i:i+p_size,j:j+p_size,k:k+p_size] += out_np |
|
|
120 |
out_counter[:, :, i:i+p_size, j:j+p_size, k:k+p_size] += 1 |
|
|
121 |
else: |
|
|
122 |
img_patch = img[:, :, i:i+p_size, j:j+p_size, k:k+p_size] |
|
|
123 |
#make a prediction on this image patch, adding extra padding during prediction and augmenting |
|
|
124 |
#the result is of the same shape and size as the original img patch |
|
|
125 |
out_np = getPatchPrediction(img_patch, model, useGPU, gpu0, extra_pad = extra_pad, test_augm = test_augm) |
|
|
126 |
|
|
|
127 |
out_total[:, :, i:i+p_size, j:j+p_size, k:k+p_size] += out_np |
|
|
128 |
out_counter[:, :, i:i+p_size, j:j+p_size, k:k+p_size] += 1 |
|
|
129 |
out_total = out_total / out_counter |
|
|
130 |
#remove padding from predictions |
|
|
131 |
nb, c, i_size, j_size, k_size = out_total.shape |
|
|
132 |
out_total = out_total[:, :, x_1_off:i_size-x_2_off, y_1_off:j_size-y_2_off, z_1_off:k_size-z_2_off] |
|
|
133 |
|
|
|
134 |
return out_total |
|
|
135 |
|
|
|
136 |
def getExtraPatchOffsets(v, low_bound, upper_bound, extra_p): |
|
|
137 |
v_left = 0 |
|
|
138 |
v_right = 0 |
|
|
139 |
if v - extra_p > low_bound: |
|
|
140 |
v_left = extra_p |
|
|
141 |
if v + extra_p < upper_bound: |
|
|
142 |
v_right = extra_p |
|
|
143 |
return v_left, v_right |
|
|
144 |
|
|
|
145 |
#list of tuple [(i_l, i_r), (j_l, j_r)] |
|
|
146 |
def removePatchOffset(np_arr, i_l, i_r, j_l, j_r, k_l, k_r): |
|
|
147 |
bn, c, s_i, s_j, s_k = np_arr.shape |
|
|
148 |
return np_arr[:,:,(i_l):(s_i-i_r), (j_l):(s_j-j_r), (k_l):(s_k-k_r)] |
|
|
149 |
|
|
|
150 |
def getPatchPrediction(img_patch, model, useGPU, gpu0, extra_pad = 0, test_augm = False): |
|
|
151 |
pd = extra_pad/2 |
|
|
152 |
padding = ((0,0), (0,0), (pd, pd), (pd, pd), (pd,pd)) |
|
|
153 |
img_patch = np.pad(img_patch, padding, 'constant') |
|
|
154 |
|
|
|
155 |
num_augm = 1 |
|
|
156 |
if test_augm: |
|
|
157 |
num_augm = 3 |
|
|
158 |
|
|
|
159 |
out_np_total = None |
|
|
160 |
for i in range(num_augm): |
|
|
161 |
img_patch_cp = np.copy(img_patch) |
|
|
162 |
#AUGMENT IMAGE |
|
|
163 |
if test_augm and i != 0: |
|
|
164 |
pass |
|
|
165 |
#apply augmentation |
|
|
166 |
rot_x, rot_y, rot_z = AUG.getRotationVal([10,10,10]) |
|
|
167 |
zoom_val = AUG.getScalingVal(0.8, 1.1) |
|
|
168 |
|
|
|
169 |
img_patch_cp = AUG.applyScale([img_patch_cp], zoom_val, [3])[0] |
|
|
170 |
img_patch_cp = AUG.applyRotation([img_patch_cp], [rot_x, rot_y, rot_z], [3])[0] |
|
|
171 |
|
|
|
172 |
#MAKE PREDICTION |
|
|
173 |
if useGPU: |
|
|
174 |
out = model(Variable(torch.from_numpy(img_patch_cp).float(),volatile = True).cuda(gpu0)) |
|
|
175 |
else: |
|
|
176 |
out = model(Variable(torch.from_numpy(img_patch_cp).float(),volatile = True)) |
|
|
177 |
out_np = out.data[0].cpu().numpy() |
|
|
178 |
#output is (1 x 3 x dim1 x dim2 x dim3) |
|
|
179 |
out_np = out_np[np.newaxis,:] |
|
|
180 |
if test_augm and i != 0: |
|
|
181 |
temp = np.copy(out_np) |
|
|
182 |
out_np = None |
|
|
183 |
#reverse augmentation on predictions |
|
|
184 |
rev_zoom_i = float(img_patch.shape[2]) / img_patch_cp.shape[2] |
|
|
185 |
rev_zoom_j = float(img_patch.shape[3]) / img_patch_cp.shape[3] |
|
|
186 |
rev_zoom_k = float(img_patch.shape[4]) / img_patch_cp.shape[4] |
|
|
187 |
|
|
|
188 |
for j in range(temp.shape[1]): |
|
|
189 |
r = AUG.applyRotation([temp[:,j:j+1,:,:,:]], [-rot_x, -rot_y, -rot_z], [3])[0] |
|
|
190 |
r = AUG.applyScale(r, [rev_zoom_i,rev_zoom_j,rev_zoom_k], [3])[0] |
|
|
191 |
|
|
|
192 |
if not isinstance(out_np, np.ndarray): |
|
|
193 |
out_np = np.zeros([1, temp.shape[1], r.shape[2], r.shape[3], r.shape[4]]) |
|
|
194 |
out_np[:, j,:,:,:] = r |
|
|
195 |
out_np = numpySoftmax(out_np, 1) |
|
|
196 |
if not isinstance(out_np_total, np.ndarray): |
|
|
197 |
if pd == 0: |
|
|
198 |
out_np_total = out_np |
|
|
199 |
else: |
|
|
200 |
out_np_total = out_np[:,:,pd:-pd, pd:-pd, pd:-pd] |
|
|
201 |
else: |
|
|
202 |
if pd ==0: |
|
|
203 |
out_np_total += out_np |
|
|
204 |
else: |
|
|
205 |
out_np_total += out_np[:,:,pd:-pd, pd:-pd, pd:-pd] |
|
|
206 |
|
|
|
207 |
return out_np_total / num_augm |
|
|
208 |
|
|
|
209 |
|
|
|
210 |
|
|
|
211 |
def numpySoftmax(x, axis_): |
|
|
212 |
e_x = np.exp(x - np.max(x)) |
|
|
213 |
return e_x / (e_x.sum(axis=axis_) + 0.00001) |