Diff of /inference.py [000000] .. [903821]

Switch to unified view

a b/inference.py
1
import math
2
import torch
3
import torch.nn.functional as F
4
import numpy as np
5
import h5py
6
import nibabel as nib
7
from medpy import metric
8
from networks.vnet import VNet
9
10
11
def calculate_metric_percase(pred, gt):
12
    dice = metric.binary.dc(pred, gt)
13
    jc = metric.binary.jc(pred, gt)
14
    hd = metric.binary.hd95(pred, gt)
15
    asd = metric.binary.asd(pred, gt)
16
17
    return dice, jc, hd, asd
18
19
20
def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1):
21
    w, h, d = image.shape
22
23
    # if the size of image is less than patch_size, then padding it
24
    add_pad = False
25
    if w < patch_size[0]:
26
        w_pad = patch_size[0]-w
27
        add_pad = True
28
    else:
29
        w_pad = 0
30
    if h < patch_size[1]:
31
        h_pad = patch_size[1]-h
32
        add_pad = True
33
    else:
34
        h_pad = 0
35
    if d < patch_size[2]:
36
        d_pad = patch_size[2]-d
37
        add_pad = True
38
    else:
39
        d_pad = 0
40
    wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
41
    hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
42
    dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
43
    if add_pad:
44
        image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
45
    ww,hh,dd = image.shape
46
47
    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
48
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
49
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
50
    # print("{}, {}, {}".format(sx, sy, sz))
51
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
52
    cnt = np.zeros(image.shape).astype(np.float32)
53
54
    for x in range(0, sx):
55
        xs = min(stride_xy*x, ww-patch_size[0])
56
        for y in range(0, sy):
57
            ys = min(stride_xy * y,hh-patch_size[1])
58
            for z in range(0, sz):
59
                zs = min(stride_z * z, dd-patch_size[2])
60
                test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
61
                test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32)
62
                test_patch = torch.from_numpy(test_patch).cuda()
63
                y1 = net(test_patch)
64
                y = F.softmax(y1, dim=1)
65
                y = y.cpu().data.numpy()
66
                y = y[0,:,:,:,:]
67
                score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
68
                  = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
69
                cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
70
                  = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
71
    score_map = score_map/np.expand_dims(cnt,axis=0)
72
    label_map = np.argmax(score_map, axis = 0)
73
    if add_pad:
74
        label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
75
        score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
76
    return label_map, score_map
77
78
def test_all_case(net, image_list, num_classes=2, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None):
79
    total_metric = 0.0
80
    for ith,image_path in enumerate(image_list):
81
        h5f = h5py.File(image_path, 'r')
82
        image = h5f['image'][:]
83
        label = h5f['label'][:]
84
        if preproc_fn is not None:
85
            image = preproc_fn(image)
86
        prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
87
88
        if np.sum(prediction)==0:
89
            single_metric = (0,0,0,0)
90
        else:
91
            single_metric = calculate_metric_percase(prediction, label[:])
92
        print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1], single_metric[2], single_metric[3]))
93
        total_metric += np.asarray(single_metric)
94
95
        if save_result:
96
            nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred.nii.gz"%(ith))
97
            nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_img.nii.gz"%(ith))
98
            nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_gt.nii.gz"%(ith))
99
    avg_metric = total_metric / len(image_list)
100
    print('average metric is {}'.format(avg_metric))
101
102
    return avg_metric
103
104
105
if __name__ == '__main__':
106
    data_path = '/***/LASet/data/'
107
    test_save_path = 'predictions/supervised'
108
    save_mode_path = 'model/LA_vnet_25_labeled/supervised/supervised_best_model.pth'
109
    net = VNet(n_channels=1,n_classes=2, normalization='batchnorm').cuda()
110
    net.load_state_dict(torch.load(save_mode_path))
111
    print("init weight from {}".format(save_mode_path))
112
    net.eval()
113
    with open(data_path + '/../test.list', 'r') as f:
114
        image_list = f.readlines()
115
    image_list = [data_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list]
116
    # 滑动窗口法
117
    avg_metric = test_all_case(net, image_list, num_classes=2,
118
                                patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
119
                                save_result=False,test_save_path=test_save_path)