Diff of /utils/test_patch.py [000000] .. [903821]

import h5py
import math
import nibabel as nib
import numpy as np
from medpy import metric
import torch
import torch.nn.functional as F
from tqdm import tqdm
from skimage.measure import label
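
# Sliding-window, patch-based inference utilities for the LA and Pancreas-CT test sets:
#   - getLargestCC:      largest-connected-component post-processing
#   - var_all_case:      quick Dice-only evaluation over a dataset's test.list (e.g. during training)
#   - test_all_case:     full evaluation (Dice, Jaccard, 95% HD, ASD) with optional NIfTI export
#   - test_single_case_first_output / test_single_case_average_output: per-volume sliding-window
#     prediction using the first decoder's output or the average over all decoder outputs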

def getLargestCC(segmentation):
    # keep only the largest connected component of a binary segmentation
    labels = label(segmentation)
    assert labels.max() != 0  # assume at least one connected component
    largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
    return largestCC

def var_all_case(model, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, dataset_name="LA"):
    # quick Dice-only evaluation over the dataset's test list (e.g. for validation during training)
    if dataset_name == "LA":
        p = '/data/omnisky/postgraduate/Yb/data_set/LASet'
        with open(p + '/test.list', 'r') as f:
            image_list = f.readlines()
        image_list = [p + "/2018LA_Seg_Training Set/" + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list]
    elif dataset_name == "Pancreas_CT":
        with open('./data/Pancreas/test.list', 'r') as f:
            image_list = f.readlines()
        image_list = ["./data/Pancreas/Pancreas_h5/" + item.replace('\n', '') + "_norm.h5" for item in image_list]
    loader = tqdm(image_list)
    total_dice = 0.0
    for image_path in loader:
        h5f = h5py.File(image_path, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        prediction, score_map = test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if np.sum(prediction) == 0:
            dice = 0
        else:
            dice = metric.binary.dc(prediction, label)
        total_dice += dice
    avg_dice = total_dice / len(image_list)
    print('average metric is {}'.format(avg_dice))
    return avg_dice

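# Full test-set evaluation: runs sliding-window inference on every volume in image_list and
# computes Dice / Jaccard / 95% Hausdorff / average surface distance per case. With
# num_outputs > 1 it additionally evaluates the average over all decoder outputs; nms=1 keeps
# only the largest connected component, and save_result=True exports NIfTI files.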
def test_all_case(model_name, num_outputs, model, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None, metric_detail=1, nms=0):

    loader = tqdm(image_list) if not metric_detail else image_list
    ith = 0
    total_metric = 0.0
    total_metric_average = 0.0
    for image_path in loader:
        h5f = h5py.File(image_path, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        if preproc_fn is not None:
            image = preproc_fn(image)
        prediction, score_map = test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if num_outputs > 1:
            prediction_average, score_map_average = test_single_case_average_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if nms:
            # post-process: keep only the largest connected component
            prediction = getLargestCC(prediction)
            if num_outputs > 1:
                prediction_average = getLargestCC(prediction_average)

        if np.sum(prediction) == 0:
            single_metric = (0, 0, 0, 0)
            if num_outputs > 1:
                single_metric_average = (0, 0, 0, 0)
        else:
            single_metric = calculate_metric_percase(prediction, label[:])
            if num_outputs > 1:
                single_metric_average = calculate_metric_percase(prediction_average, label[:])

        if metric_detail:
            print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1], single_metric[2], single_metric[3]))
            if num_outputs > 1:
                print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric_average[0], single_metric_average[1], single_metric_average[2], single_metric_average[3]))

        total_metric += np.asarray(single_metric)
        if num_outputs > 1:
            total_metric_average += np.asarray(single_metric_average)

        if save_result:
            nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred.nii.gz" % ith)
            nib.save(nib.Nifti1Image(score_map[0].astype(np.float32), np.eye(4)), test_save_path + "%02d_scores.nii.gz" % ith)
            if num_outputs > 1:
                nib.save(nib.Nifti1Image(prediction_average.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred_average.nii.gz" % ith)
                nib.save(nib.Nifti1Image(score_map_average[0].astype(np.float32), np.eye(4)), test_save_path + "%02d_scores_average.nii.gz" % ith)
            nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_img.nii.gz" % ith)
            nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_gt.nii.gz" % ith)

        ith += 1

    avg_metric = total_metric / len(image_list)
    print('average metric of decoder 1 is {}'.format(avg_metric))
    if num_outputs > 1:
        avg_metric_average = total_metric_average / len(image_list)
        print('average metric of all decoders is {}'.format(avg_metric_average))

    with open(test_save_path + '../{}_performance.txt'.format(model_name), 'w') as f:
        f.write('average metric of decoder 1 is {} \n'.format(avg_metric))
        if num_outputs > 1:
            f.write('average metric of all decoders is {} \n'.format(avg_metric_average))
    return avg_metric

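# Both test_single_case_* helpers implement the same sliding-window scheme: the volume is
# zero-padded up to the patch size if needed, patches are extracted on a (stride_xy, stride_xy,
# stride_z) grid, the per-patch foreground probabilities are accumulated into score_map and
# normalised by the visit count cnt, then thresholded at 0.5 and cropped back to the
# original shape.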
def test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=1):
    w, h, d = image.shape

    # if the image is smaller than the patch size, pad it up to the patch size
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    # number of sliding-window positions along each axis
    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    # print("{}, {}, {}".format(sx, sy, sz))
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()

                with torch.no_grad():
                    y1 = model(test_patch)
                    # multi-decoder networks return a list/tuple of outputs; keep only the first decoder's
                    if len(y1) > 1:
                        y1 = y1[0]
                    y1 = F.softmax(y1, dim=1)
                y1 = y1.cpu().data.numpy()
                y1 = y1[0, 1, :, :, :]  # foreground (class 1) probability
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y1
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1

    # average overlapping predictions and threshold the foreground probability
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = (score_map[0] > 0.5).astype(np.uint8)  # np.int was removed in NumPy >= 1.24
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map

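# Same sliding-window procedure as test_single_case_first_output, but each patch's score is
# the average over all of the network's decoder outputs instead of the first decoder only.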
def test_single_case_average_output(net, image, stride_xy, stride_z, patch_size, num_classes=1):
    w, h, d = image.shape

    # if the image is smaller than the patch size, pad it up to the patch size
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    # number of sliding-window positions along each axis
    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    # print("{}, {}, {}".format(sx, sy, sz))
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()

                with torch.no_grad():
                    y_logit = net(test_patch)
                    num_outputs = len(y_logit)
                    y_prob = torch.zeros(y_logit[0].shape).cuda()
                    for idx in range(num_outputs):
                        # convert each decoder's logits to probabilities before averaging,
                        # so the 0.5 threshold below acts on a probability map
                        y_prob += F.softmax(y_logit[idx], dim=1)
                    y_prob /= num_outputs

                y_prob = y_prob.cpu().data.numpy()
                y_prob = y_prob[0, 1, :, :, :]  # foreground (class 1) probability
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y_prob
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1

    # average overlapping predictions and threshold the foreground probability
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = (score_map[0] > 0.5).astype(np.uint8)  # np.int was removed in NumPy >= 1.24
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map

def calculate_metric_percase(pred, gt):
    # per-case metrics: Dice, Jaccard, 95% Hausdorff distance, average surface distance
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)

    return dice, jc, hd, asd
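

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The checkpoint path,
# model name, output directory and num_outputs value below are illustrative
# placeholders; it assumes the checkpoint was saved as a whole model with
# torch.save(model, path) and that the data layout matches var_all_case above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import os

    data_root = '/data/omnisky/postgraduate/Yb/data_set/LASet'
    with open(data_root + '/test.list', 'r') as f:
        image_list = [data_root + "/2018LA_Seg_Training Set/" + item.replace('\n', '') + "/mri_norm2.h5"
                      for item in f.readlines()]

    model = torch.load('./checkpoints/best_model.pth')  # hypothetical checkpoint (whole pickled model)
    model = model.cuda()
    model.eval()

    save_path = './predictions/'  # hypothetical output directory
    os.makedirs(save_path, exist_ok=True)

    avg_metric = test_all_case('example_model', num_outputs=2, model=model,
                               image_list=image_list, num_classes=2,
                               patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                               save_result=True, test_save_path=save_path,
                               metric_detail=1, nms=0)
    print('final average metric: {}'.format(avg_metric))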