Diff of /test_reconstruction.py [000000] .. [721c7a]

Switch to unified view

a b/test_reconstruction.py
1
import torch.nn as nn
2
import numpy as np
3
4
import torch.optim as optim
5
from torch.utils.data import DataLoader
6
from torch.autograd import Variable
7
import torch.nn.functional as F
8
from tqdm import tqdm
9
import time
10
from torch.utils.tensorboard import SummaryWriter
11
import matplotlib.pyplot as plt
12
import pdb
13
import imageio
14
import os
15
import sys
16
import nibabel as nib
17
import neural_renderer as nr
18
import pyvista as pv
19
20
from network_reconstruction import *
21
from dataio_reconstruction import *
22
from utils import *
23
import vtk
24
import scipy.io
25
import csv
26
import pdb
27
28
n_class = 4
29
n_worker = 4
30
bs = 1
31
T_num = 50 # number of frames
32
width = 128
33
height = 128
34
depth = 64
35
temper = 2
36
sa_sliceall = [12, 17, 22, 27, 32, 37, 42, 47, 52]
37
38
39
model_save_path = './models/model_reconstruction'
40
# pytorch only saves the last model
41
Deform_save_path = os.path.join(model_save_path, 'deform.pth')
42
Motion_LA_save_path = os.path.join(model_save_path, 'multiview.pth')
43
44
def test(sub_path):
45
    DeformNet.eval()
46
    MV_LA.eval()
47
48
    hd_SA = []
49
    hd_2CH = []
50
    hd_4CH = []
51
52
    bfscore_SA = []
53
    bfscore_2CH = []
54
    bfscore_4CH = []
55
56
57
    for name in glob.glob(os.path.join(sub_path, '*')):
58
59
        sub_name = name.split('/')[-1]
60
        print (sub_name)
61
62
        image_sa_bank, image_2ch_bank, image_4ch_bank, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
63
        vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = load_data(
64
        sub_path, sub_name, T_num, rand_frame=0)
65
66
        img_sa_ed = torch.from_numpy(image_sa_bank[1:2, :, :, :])
67
        img_2ch_t = torch.from_numpy(image_2ch_bank[0:1, :, :, :])
68
        img_2ch_ed = torch.from_numpy(image_2ch_bank[1:2, :, :, :])
69
        img_4ch_t = torch.from_numpy(image_4ch_bank[0:1, :, :, :])
70
        img_4ch_ed = torch.from_numpy(image_4ch_bank[1:2, :, :, :])
71
72
        with torch.no_grad():
73
74
            x_sa_ed = img_sa_ed.type(Tensor)
75
            x_2ch_t = img_2ch_t.type(Tensor)
76
            x_2ch_ed = img_2ch_ed.type(Tensor)
77
            x_4ch_t = img_4ch_t.type(Tensor)
78
            x_4ch_ed = img_4ch_ed.type(Tensor)
79
80
81
            aff_sa_inv = torch.from_numpy(affine_inv[0, :, :]).type(Tensor).unsqueeze(0)
82
            aff_sa = torch.from_numpy(affine[0, :, :]).type(Tensor).unsqueeze(0)
83
            aff_2ch_inv = torch.from_numpy(affine_inv[1, :, :]).type(Tensor).unsqueeze(0)
84
            aff_4ch_inv = torch.from_numpy(affine_inv[2, :, :]).type(Tensor).unsqueeze(0)
85
86
            origin_sa = torch.from_numpy(origin[0:1, :]).type(Tensor)
87
            origin_2ch = torch.from_numpy(origin[1:2, :]).type(Tensor)
88
            origin_4ch = torch.from_numpy(origin[2:3, :]).type(Tensor)
89
90
            vertex_tpl_0 = torch.from_numpy(vertex_tpl_ed).unsqueeze(0).permute(0, 2, 1).type(Tensor)  # [bs, 3, number of vertices]
91
92
93
94
            net_la = MV_LA(x_2ch_t, x_2ch_ed, x_4ch_t, x_4ch_ed)
95
            net_df = DeformNet(x_sa_ed, net_la['conv2s_2ch'], net_la['conv2s_4ch'])
96
97
            # ---------------sample from 3D motion fields
98
            # translate coordinate
99
            v_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], vertex_tpl_0) + aff_sa_inv[:, :3, 3:4]
100
            v_ed = v_ed_o.permute(0, 2, 1) - origin_sa  # [bs, number of vertices,3]
101
            # normalize translated coordinate (image space) to [-1,1]
102
            v_ed_x = (v_ed[:, :, 0:1] - (width / 2)) / (width / 2)
103
            v_ed_y = (v_ed[:, :, 1:2] - (height / 2)) / (height / 2)
104
            v_ed_z = (v_ed[:, :, 2:3] - (depth / 2)) / (depth / 2)
105
            v_ed_norm = torch.cat((v_ed_x, v_ed_y, v_ed_z), 2)
106
            v_ed_norm_expand = v_ed_norm.unsqueeze(1).unsqueeze(1)  # [bs, 1, 1,number of vertices,3]
107
108
            # sample from 3D motion field
109
            pxx = F.grid_sample(net_df['out_def_ed'][:, 0:1], v_ed_norm_expand, align_corners=True).transpose(4, 3)
110
            pyy = F.grid_sample(net_df['out_def_ed'][:, 1:2], v_ed_norm_expand, align_corners=True).transpose(4, 3)
111
            pzz = F.grid_sample(net_df['out_def_ed'][:, 2:3], v_ed_norm_expand, align_corners=True).transpose(4, 3)
112
            delta_p = torch.cat((pxx, pyy, pzz), 4)
113
            # updata coor (image space)
114
            # print (v_ed.shape, delta_p.shape)
115
            v_0_norm_expand = v_ed_norm_expand + delta_p  # [bs, 1, 1,number of vertices,3]
116
            # t frame
117
            v_0_norm = v_0_norm_expand.squeeze(1).squeeze(1)
118
            v_0_x = v_0_norm[:, :, 0:1] * (width / 2) + (width / 2)
119
            v_0_y = v_0_norm[:, :, 1:2] * (height / 2) + (height / 2)
120
            v_0_z = v_0_norm[:, :, 2:3] * (depth / 2) + (depth / 2)
121
            v_0_crop = torch.cat((v_0_x, v_0_y, v_0_z), 2)
122
            # translate back to mesh space
123
            v_0 = v_0_crop + origin_sa  # [bs, number of vertices,3]
124
            pred_v_0 = torch.matmul(aff_sa[:, :3, :3], v_0.permute(0, 2, 1)) + aff_sa[:, :3,
125
                                                                               3:4]  # [bs, 3, number of vertices]
126
127
            # -------------- differentialable slicer
128
129
            # coordinate transformation np.dot(aff_sa_SR_inv[:3,:3], points_ED.T) + aff_sa_SR_inv[:3,3:4]
130
            v_sa_hat_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], pred_v_0) + aff_sa_inv[:, :3, 3:4]
131
            v_sa_hat_ed = v_sa_hat_ed_o.permute(0, 2, 1) - origin_sa
132
            # print (v_sa_hat_t.shape)
133
            v_2ch_hat_ed_o = torch.matmul(aff_2ch_inv[:, :3, :3], pred_v_0) + aff_2ch_inv[:, :3, 3:4]
134
            v_2ch_hat_ed = v_2ch_hat_ed_o.permute(0, 2, 1) - origin_2ch
135
            v_4ch_hat_ed_o = torch.matmul(aff_4ch_inv[:, :3, :3], pred_v_0) + aff_4ch_inv[:, :3, 3:4]
136
            v_4ch_hat_ed = v_4ch_hat_ed_o.permute(0, 2, 1) - origin_4ch
137
138
            # project vertices satisfying threshood
139
            # project to SAX slices, project all vertices to a target plane,
140
            # vertices selection is moved to loss computation function
141
            v_sa_hat_ed_x = torch.clamp(v_sa_hat_ed[:, :, 0:1], min=0, max=height - 1)
142
            v_sa_hat_ed_y = torch.clamp(v_sa_hat_ed[:, :, 1:2], min=0, max=width - 1)
143
            v_sa_hat_ed_cp = torch.cat((v_sa_hat_ed_x, v_sa_hat_ed_y, v_sa_hat_ed[:, :, 2:3]), 2)
144
145
146
            # project to LAX 2CH view
147
            v_2ch_hat_ed_x = torch.clamp(v_2ch_hat_ed[:, :, 0:1], min=0, max=height - 1)
148
            v_2ch_hat_ed_y = torch.clamp(v_2ch_hat_ed[:, :, 1:2], min=0, max=width - 1)
149
            v_2ch_hat_ed_cp = torch.cat((v_2ch_hat_ed_x, v_2ch_hat_ed_y, v_2ch_hat_ed[:, :, 2:3]), 2)
150
151
152
            # project to LAX 4CH view
153
            v_4ch_hat_ed_x = torch.clamp(v_4ch_hat_ed[:, :, 0:1], min=0, max=height - 1)
154
            v_4ch_hat_ed_y = torch.clamp(v_4ch_hat_ed[:, :, 1:2], min=0, max=width - 1)
155
            v_4ch_hat_ed_cp = torch.cat((v_4ch_hat_ed_x, v_4ch_hat_ed_y, v_4ch_hat_ed[:, :, 2:3]), 2)
156
157
158
159
            # slicer
160
161
            mcd_sa, hd_sa = compute_sa_mcd_hd(v_sa_hat_ed_cp, contour_sa_ed, sa_sliceall)
162
            bfscore_sa = compute_sa_Fboundary(v_sa_hat_ed_cp, contour_sa_ed, sa_sliceall, height, width)
163
164
165
            idx_2ch = slice_2D(v_2ch_hat_ed_cp, 0)
166
            idx_2ch_gt = np.stack(np.nonzero(contour_2ch_ed), 1)
167
            mcd_2ch, hd_2ch = distance_metric(idx_2ch, idx_2ch_gt, 1.25)
168
            la_2ch_pred_con = np.zeros(shape=(height, width), dtype=np.uint8)
169
            for j in range(idx_2ch.shape[0]):
170
                la_2ch_pred_con[idx_2ch[j, 0], idx_2ch[j, 1]] = 1
171
            bfscore_2ch = compute_la_Fboundary(la_2ch_pred_con, contour_2ch_ed)
172
173
174
            idx_4ch = slice_2D(v_4ch_hat_ed_cp, 0)
175
            idx_4ch_gt = np.stack(np.nonzero(contour_4ch_ed), 1)
176
            mcd_4ch, hd_4ch = distance_metric(idx_4ch, idx_4ch_gt, 1.25)
177
            la_4ch_pred_con = np.zeros(shape=(height, width), dtype=np.uint8)
178
            for j in range(idx_4ch.shape[0]):
179
                la_4ch_pred_con[idx_4ch[j, 0], idx_4ch[j, 1]] = 1
180
            bfscore_4ch = compute_la_Fboundary(la_4ch_pred_con, contour_4ch_ed)
181
182
183
            if (hd_sa != None):
184
                hd_SA.append(hd_sa)
185
            if (hd_2ch != None):
186
                hd_2CH.append(hd_2ch)
187
            if (hd_4ch != None):
188
                hd_4CH.append(hd_4ch)
189
190
            if (bfscore_sa != None):
191
                bfscore_SA.append(bfscore_sa)
192
            if (bfscore_2ch != None):
193
                bfscore_2CH.append(bfscore_2ch)
194
            if (bfscore_4ch != None):
195
                bfscore_4CH.append(bfscore_4ch)
196
197
198
            print (hd_sa, hd_2ch, hd_4ch)
199
            print (bfscore_sa, bfscore_2ch, bfscore_4ch)
200
201
202
203
    print('SA HD: {:.4f}({:.4f}), 2CH HD: {:.4f}({:.4f}), 4CH HD: {:.4f}({:.4f})'
204
          .format(np.mean(hd_SA), np.std(hd_SA), np.mean(hd_2CH), np.std(hd_2CH), np.mean(hd_4CH), np.std(hd_4CH)))
205
    print('SA BFscore: {:.4f}({:.4f}), 2CH BFscore: {:.4f}({:.4f}), 4CH BFscore: {:.4f}({:.4f})'
206
          .format(np.mean(bfscore_SA), np.std(bfscore_SA), np.mean(bfscore_2CH), np.std(bfscore_2CH),
207
                  np.mean(bfscore_4CH), np.std(bfscore_4CH)))
208
209
210
211
212
test_data_path = '/test_data_path'
213
214
215
DeformNet = deformnet().cuda()
216
MV_LA = Mesh_2d().cuda()
217
218
DeformNet.load_state_dict(torch.load(Deform_save_path), strict=True)
219
MV_LA.load_state_dict(torch.load(Motion_LA_save_path), strict=True)
220
221
Tensor = torch.cuda.FloatTensor
222
TensorLong = torch.cuda.LongTensor
223
224
225
start = time.time()
226
test(test_data_path)
227
end = time.time()
228
print("testing took {:.8f}".format(end - start))