Diff of /test_motion.py [000000] .. [721c7a]

import torch
import torch.nn as nn
import numpy as np

import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import time
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import pdb
import imageio
import os
import sys
import glob
import nibabel as nib
import neural_renderer as nr
import pyvista as pv

from network_motion import *
from dataio_motion import *
from utils import *
import vtk
import scipy.io
import csv

n_class = 4
n_worker = 4
bs = 1
T_num = 50  # number of frames
width = 128
height = 128
depth = 64
temper = 3
sa_sliceall = [12, 17, 22, 27, 32, 37, 42, 47, 52]  # SAX slice indices used for contour evaluation
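
# width/height/depth give the SAX volume grid that vertex coordinates are
# normalised against before F.grid_sample (see test_all below).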

model_save_path = './models/model_motion'
# pytorch only saves the last model
Motion_save_path = os.path.join(model_save_path, 'motionEst.pth')
Motion_LA_save_path = os.path.join(model_save_path, 'multiview.pth')


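# Evaluate the motion networks on every subject under sub_path: warp the ED mesh
# with the predicted motion field to obtain ES vertices, project them into the
# SAX, 2CH and 4CH views, and score the projected contours against ground truth
# (Hausdorff distance and boundary F-score).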
def test_all(sub_path):
    MotionNet.eval()
    MV_LA.eval()

    hd_SA = []
    hd_2CH = []
    hd_4CH = []

    bfscore_SA = []
    bfscore_2CH = []
    bfscore_4CH = []

    for name in glob.glob(os.path.join(sub_path, '*')):

        sub_name = name.split('/')[-1]
        print(sub_name)

        image_sa_ES_bank, image_2ch_ES_bank, image_4ch_ES_bank, \
        contour_sa_es, contour_2ch_es, contour_4ch_es, vertex_ed, faces, affine_inv, affine, origin = load_data_ES(
            sub_path, sub_name)
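
        # load_data_ES is assumed to stack the ES frame at index 0 and the ED
        # frame at index 1 of each image bank (see dataio_motion).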
        img_sa_es = torch.from_numpy(image_sa_ES_bank[0:1, :, :, :])
        img_sa_ed = torch.from_numpy(image_sa_ES_bank[1:2, :, :, :])
        img_2ch_es = torch.from_numpy(image_2ch_ES_bank[0:1, :, :, :])
        img_2ch_ed = torch.from_numpy(image_2ch_ES_bank[1:2, :, :, :])
        img_4ch_es = torch.from_numpy(image_4ch_ES_bank[0:1, :, :, :])
        img_4ch_ed = torch.from_numpy(image_4ch_ES_bank[1:2, :, :, :])

        with torch.no_grad():
            x_sa_es = img_sa_es.type(Tensor)
            x_sa_ed = img_sa_ed.type(Tensor)
            x_2ch_es = img_2ch_es.type(Tensor)
            x_2ch_ed = img_2ch_ed.type(Tensor)
            x_4ch_es = img_4ch_es.type(Tensor)
            x_4ch_ed = img_4ch_ed.type(Tensor)
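
            # rows 0, 1 and 2 of affine/affine_inv hold the SAX, 2CH and 4CH
            # voxel <-> world transforms, respectively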
            aff_sa_inv = torch.from_numpy(affine_inv[0, :, :]).type(Tensor).unsqueeze(0)
            aff_sa = torch.from_numpy(affine[0, :, :]).type(Tensor).unsqueeze(0)
            aff_2ch_inv = torch.from_numpy(affine_inv[1, :, :]).type(Tensor).unsqueeze(0)
            aff_4ch_inv = torch.from_numpy(affine_inv[2, :, :]).type(Tensor).unsqueeze(0)

            origin_sa = torch.from_numpy(origin[0, :]).type(Tensor).unsqueeze(0)
            origin_2ch = torch.from_numpy(origin[1, :]).type(Tensor).unsqueeze(0)
            origin_4ch = torch.from_numpy(origin[2, :]).type(Tensor).unsqueeze(0)

            vertex_0 = torch.from_numpy(vertex_ed).unsqueeze(0).permute(0, 2, 1).type(Tensor)  # [bs, 3, number of vertices]

            net_la = MV_LA(x_2ch_es, x_2ch_ed, x_4ch_es, x_4ch_ed)
            net_sa = MotionNet(x_sa_es, x_sa_ed, net_la['conv2_2ch'], net_la['conv2s_2ch'],
                               net_la['conv2_4ch'], net_la['conv2s_4ch'])
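
            # net_sa['out'] is the dense 3-channel 3D motion field; its values are
            # displacements in normalised grid coordinates, per the sampling below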
            # --------------- sample from the 3D motion field ---------------
            # translate vertex coordinates from world space to SAX image space
            v_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], vertex_0) + aff_sa_inv[:, :3, 3:4]
            v_ed = v_ed_o.permute(0, 2, 1) - origin_sa  # [bs, number of vertices, 3]
            # normalise to [-1, 1] as required by F.grid_sample
            v_ed_x = (v_ed[:, :, 0:1] - (width / 2)) / (width / 2)
            v_ed_y = (v_ed[:, :, 1:2] - (height / 2)) / (height / 2)
            v_ed_z = (v_ed[:, :, 2:3] - (depth / 2)) / (depth / 2)
            v_ed_norm = torch.cat((v_ed_x, v_ed_y, v_ed_z), 2)
            v_ed_norm_expand = v_ed_norm.unsqueeze(1).unsqueeze(1)  # [bs, 1, 1, number of vertices, 3]

            # sample the per-vertex displacement from the 3D motion field
            pxx = F.grid_sample(net_sa['out'][:, 0:1], v_ed_norm_expand, align_corners=True).transpose(4, 3)
            pyy = F.grid_sample(net_sa['out'][:, 1:2], v_ed_norm_expand, align_corners=True).transpose(4, 3)
            pzz = F.grid_sample(net_sa['out'][:, 2:3], v_ed_norm_expand, align_corners=True).transpose(4, 3)
            delta_p = torch.cat((pxx, pyy, pzz), 4)
            # update coordinates (normalised image space)
            v_es_norm_expand = v_ed_norm_expand + delta_p  # [bs, 1, 1, number of vertices, 3]
            # de-normalise back to voxel coordinates
            v_es_norm = v_es_norm_expand.squeeze(1).squeeze(1)
            v_es_x = v_es_norm[:, :, 0:1] * (width / 2) + (width / 2)
            v_es_y = v_es_norm[:, :, 1:2] * (height / 2) + (height / 2)
            v_es_z = v_es_norm[:, :, 2:3] * (depth / 2) + (depth / 2)
            v_es_crop = torch.cat((v_es_x, v_es_y, v_es_z), 2)
            # translate back to mesh (world) space
            v_es = v_es_crop + origin_sa  # [bs, number of vertices, 3]
            pred_vertex_es = torch.matmul(aff_sa[:, :3, :3], v_es.permute(0, 2, 1)) + aff_sa[:, :3, 3:4]  # [bs, 3, number of vertices]
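
            # evaluation: mean contour distance (MCD), Hausdorff distance (HD) and
            # boundary F-score against ground-truth contours in each view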
            # -------------------- compute segmentation evaluation --------------------
            # coordinate transformation: np.dot(aff_inv[:3, :3], points.T) + aff_inv[:3, 3:4]
            v_sa_hat_es_o = torch.matmul(aff_sa_inv[:, :3, :3], pred_vertex_es) + aff_sa_inv[:, :3, 3:4]
            v_sa_hat_es = v_sa_hat_es_o.permute(0, 2, 1) - origin_sa
            v_2ch_hat_es_o = torch.matmul(aff_2ch_inv[:, :3, :3], pred_vertex_es) + aff_2ch_inv[:, :3, 3:4]
            v_2ch_hat_es = v_2ch_hat_es_o.permute(0, 2, 1) - origin_2ch
            v_4ch_hat_es_o = torch.matmul(aff_4ch_inv[:, :3, :3], pred_vertex_es) + aff_4ch_inv[:, :3, 3:4]
            v_4ch_hat_es = v_4ch_hat_es_o.permute(0, 2, 1) - origin_4ch

            # project all vertices to the SAX slice planes; vertices satisfying the
            # distance threshold are selected inside the metric computation functions
            v_sa_hat_es_x = torch.clamp(v_sa_hat_es[:, :, 0:1], min=0, max=height - 1)
            v_sa_hat_es_y = torch.clamp(v_sa_hat_es[:, :, 1:2], min=0, max=width - 1)
            v_sa_hat_es_cp = torch.cat((v_sa_hat_es_x, v_sa_hat_es_y, v_sa_hat_es[:, :, 2:3]), 2)

            mcd_sa, hd_sa = compute_sa_mcd_hd(v_sa_hat_es_cp, contour_sa_es, sa_sliceall)
            bfscore_sa = compute_sa_Fboundary(v_sa_hat_es_cp, contour_sa_es, sa_sliceall, height, width)

            # project to LAX 2CH view
            v_2ch_hat_es_x = torch.clamp(v_2ch_hat_es[:, :, 0:1], min=0, max=height - 1)
            v_2ch_hat_es_y = torch.clamp(v_2ch_hat_es[:, :, 1:2], min=0, max=width - 1)
            v_2ch_hat_es_cp = torch.cat((v_2ch_hat_es_x, v_2ch_hat_es_y, v_2ch_hat_es[:, :, 2:3]), 2)

            idx_2ch = slice_2D(v_2ch_hat_es_cp, 0)
            idx_2ch_gt = np.stack(np.nonzero(contour_2ch_es), 1)
            mcd_2ch, hd_2ch = distance_metric(idx_2ch, idx_2ch_gt, 1.25)  # 1.25: presumably the in-plane pixel spacing (mm)

            # rasterise the projected vertices into a binary contour image
            la_2ch_pred_con = np.zeros(shape=(height, width), dtype=np.uint8)
            for j in range(idx_2ch.shape[0]):
                la_2ch_pred_con[idx_2ch[j, 0], idx_2ch[j, 1]] = 1

            bfscore_2ch = compute_la_Fboundary(la_2ch_pred_con, contour_2ch_es)

            # project to LAX 4CH view
            v_4ch_hat_es_x = torch.clamp(v_4ch_hat_es[:, :, 0:1], min=0, max=height - 1)
            v_4ch_hat_es_y = torch.clamp(v_4ch_hat_es[:, :, 1:2], min=0, max=width - 1)
            v_4ch_hat_es_cp = torch.cat((v_4ch_hat_es_x, v_4ch_hat_es_y, v_4ch_hat_es[:, :, 2:3]), 2)

            idx_4ch = slice_2D(v_4ch_hat_es_cp, 0)
            idx_4ch_gt = np.stack(np.nonzero(contour_4ch_es), 1)
            mcd_4ch, hd_4ch = distance_metric(idx_4ch, idx_4ch_gt, 1.25)

            la_4ch_pred_con = np.zeros(shape=(height, width), dtype=np.uint8)
            for j in range(idx_4ch.shape[0]):
                la_4ch_pred_con[idx_4ch[j, 0], idx_4ch[j, 1]] = 1

            bfscore_4ch = compute_la_Fboundary(la_4ch_pred_con, contour_4ch_es)

            # the metric helpers may return None; only aggregate valid results
            if hd_sa is not None:
                hd_SA.append(hd_sa)
            if hd_2ch is not None:
                hd_2CH.append(hd_2ch)
            if hd_4ch is not None:
                hd_4CH.append(hd_4ch)

            if bfscore_sa is not None:
                bfscore_SA.append(bfscore_sa)
            if bfscore_2ch is not None:
                bfscore_2CH.append(bfscore_2ch)
            if bfscore_4ch is not None:
                bfscore_4CH.append(bfscore_4ch)

            print(hd_sa, hd_2ch, hd_4ch)
            print(bfscore_sa, bfscore_2ch, bfscore_4ch)

    print('SA HD: {:.4f}({:.4f}), 2CH HD: {:.4f}({:.4f}), 4CH HD: {:.4f}({:.4f})'
          .format(np.mean(hd_SA), np.std(hd_SA), np.mean(hd_2CH), np.std(hd_2CH), np.mean(hd_4CH), np.std(hd_4CH)))
    print('SA BFscore: {:.4f}({:.4f}), 2CH BFscore: {:.4f}({:.4f}), 4CH BFscore: {:.4f}({:.4f})'
          .format(np.mean(bfscore_SA), np.std(bfscore_SA), np.mean(bfscore_2CH), np.std(bfscore_2CH),
                  np.mean(bfscore_4CH), np.std(bfscore_4CH)))

    return
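
# ----- script entry: build the test loader, restore both networks, evaluate -----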

test_data_path = '/test_data_path'

test_set = TestDataset(test_data_path)
testing_data_loader = DataLoader(dataset=test_set, num_workers=n_worker, batch_size=bs, shuffle=False)

MotionNet = MotionMesh_25d().cuda()
MV_LA = Mesh_2d().cuda()

MotionNet.load_state_dict(torch.load(Motion_save_path), strict=True)
MV_LA.load_state_dict(torch.load(Motion_LA_save_path), strict=True)

Tensor = torch.cuda.FloatTensor
TensorLong = torch.cuda.LongTensor


start = time.time()
test_all(test_data_path)
end = time.time()
print("testing took {:.8f} seconds".format(end - start))