Diff of /train_motion.py [000000] .. [721c7a]

Switch to unified view

a b/train_motion.py
1
import torch.nn as nn
2
import numpy as np
3
import itertools
4
import os
5
import sys
6
7
import torch.optim as optim
8
from torch.utils.data import DataLoader
9
from torch.autograd import Variable
10
from tqdm import tqdm
11
import time
12
from torch.utils.tensorboard import SummaryWriter
13
from pytorch3d.structures import Meshes
14
from pytorch3d import loss
15
16
from network_motion import *
17
from dataio_motion import *
18
from utils import *
19
20
21
lr = 1e-4
22
n_worker = 4
23
bs = 8
24
n_epoch = 400
25
base_err = 1000
26
27
w_smooth = 150
28
w_reg = 20
29
w_h = 0.5
30
width = 128
31
height = 128
32
depth = 64
33
sa_idx = [12, 17, 22, 27, 32, 37, 42, 47, 52]
34
temper = 3
35
36
37
38
model_save_path = './models/model_motion'
39
if not os.path.exists(model_save_path):
40
    os.makedirs(model_save_path)
41
42
# pytorch only saves the last model
43
Motion_save_path = os.path.join(model_save_path, 'motionEst.pth')
44
Motion_LA_save_path = os.path.join(model_save_path, 'multiview.pth')
45
46
flow_criterion = nn.MSELoss()
47
MotionNet = MotionMesh_25d().cuda()
48
MV_LA = Mesh_2d().cuda()
49
50
optimizer = optim.Adam(filter(lambda p: p.requires_grad,
51
                              itertools.chain(MotionNet.parameters(), MV_LA.parameters())), lr=lr)
52
Tensor = torch.cuda.FloatTensor
53
TensorLong = torch.cuda.LongTensor
54
55
# visualisation
56
writer = SummaryWriter('./runs/model_motion')
57
58
59
60
def train(epoch):
61
    MotionNet.train()
62
    MV_LA.train()
63
64
    epoch_loss = []
65
    epoch_seg_loss = []
66
    epoch_smooth_loss = []
67
    epoch_reg_loss = []
68
    epoch_huber_loss = []
69
70
    Myo_dice_sa = []
71
    Myo_dice_2ch = []
72
    Myo_dice_4ch = []
73
74
75
    for batch_idx, batch in tqdm(enumerate(training_data_loader, 1),
76
                                 total=len(training_data_loader)):
77
78
        img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, \
79
        contour_sa, contour_2ch, contour_4ch, \
80
        vertex_ed, faces, affine_inv, affine, origin = batch
81
82
83
        x_sa_t = Variable(img_sa_t.type(Tensor))
84
        x_sa_ed = Variable(img_sa_ed.type(Tensor))
85
        x_2ch_t = Variable(img_2ch_t.type(Tensor))
86
        x_2ch_ed = Variable(img_2ch_ed.type(Tensor))
87
        x_4ch_t = Variable(img_4ch_t.type(Tensor))
88
        x_4ch_ed = Variable(img_4ch_ed.type(Tensor))
89
90
        x_sa_t_5D = Variable(img_sa_t.unsqueeze(1).type(Tensor))
91
        x_sa_ed_5D = Variable(img_sa_ed.unsqueeze(1).type(Tensor))
92
93
94
        con_sa = Variable(contour_sa.type(TensorLong))  # [bs, slices, H, W]
95
        con_2ch = Variable(contour_2ch.type(TensorLong))  # [bs, H, W]
96
        con_4ch = Variable(contour_4ch.type(TensorLong))  # [bs, H, W]
97
98
        aff_sa_inv = Variable(affine_inv[:, 0,:,:].type(Tensor))
99
        aff_sa = Variable(affine[:, 0,:,:].type(Tensor))
100
        aff_2ch_inv = Variable(affine_inv[:, 1,:,:].type(Tensor))
101
        aff_4ch_inv = Variable(affine_inv[:, 2,:,:].type(Tensor))
102
103
        origin_sa = Variable(origin[:, 0:1, :].type(Tensor))
104
        origin_2ch = Variable(origin[:, 1:2, :].type(Tensor))
105
        origin_4ch = Variable(origin[:, 2:3, :].type(Tensor))
106
107
        vertex_0 = Variable(vertex_ed.permute(0,2,1).type(Tensor)) # [bs, 3, number of vertices]
108
        faces_0 = Variable(faces.type(Tensor)) # [bs, number of faces, 3]
109
110
111
        optimizer.zero_grad()
112
113
        net_la = MV_LA(x_2ch_t, x_2ch_ed, x_4ch_t, x_4ch_ed)
114
        net_sa = MotionNet(x_sa_t, x_sa_ed, net_la['conv2_2ch'], net_la['conv2s_2ch'], net_la['conv2_4ch'], net_la['conv2s_4ch'])
115
116
        # ---------------sample from 3D motion fields
117
        #translate coordinate
118
        v_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], vertex_0) + aff_sa_inv[:, :3, 3:4]
119
        v_ed = v_ed_o.permute(0, 2, 1) - origin_sa  # [bs, number of vertices,3]
120
121
        # normalize translated coordinate (image space) to [-1,1]
122
        v_ed_x = (v_ed[:, :, 0:1] - (width / 2)) / (width / 2)
123
        v_ed_y = (v_ed[:, :, 1:2] - (height / 2)) / (height / 2)
124
        v_ed_z = (v_ed[:, :, 2:3] - (depth / 2)) / (depth / 2)
125
        v_ed_norm = torch.cat((v_ed_x, v_ed_y, v_ed_z), 2)
126
        v_ed_norm_expand = v_ed_norm.unsqueeze(1).unsqueeze(1)  # [bs, 1, 1,number of vertices,3]
127
128
        # sample from 3D motion field
129
        pxx = F.grid_sample(net_sa['out'][:, 0:1], v_ed_norm_expand, align_corners=True).transpose(4, 3)
130
        pyy = F.grid_sample(net_sa['out'][:, 1:2], v_ed_norm_expand, align_corners=True).transpose(4, 3)
131
        pzz = F.grid_sample(net_sa['out'][:, 2:3], v_ed_norm_expand, align_corners=True).transpose(4, 3)
132
        delta_p = torch.cat((pxx, pyy, pzz), 4)
133
        # updata coor (image space)
134
        # print (v_ed.shape, delta_p.shape)
135
        v_t_norm_expand = v_ed_norm_expand + delta_p  # [bs, 1, 1,number of vertices,3]
136
        # t frame
137
        v_t_norm = v_t_norm_expand.squeeze(1).squeeze(1)
138
        v_t_x = v_t_norm[:, :, 0:1] * (width / 2) + (width / 2)
139
        v_t_y = v_t_norm[:, :, 1:2] * (height / 2) + (height / 2)
140
        v_t_z = v_t_norm[:, :, 2:3] * (depth / 2) + (depth / 2)
141
        v_t_crop = torch.cat((v_t_x, v_t_y, v_t_z), 2)
142
        # translate back to mesh space
143
        v_t = v_t_crop + origin_sa  # [bs, number of vertices,3]
144
        pred_vertex_t = torch.matmul(aff_sa[:, :3, :3], v_t.permute(0,2,1)) + aff_sa[:, :3, 3:4] # [bs, 3, number of vertices]
145
        # print (pred_vertex_t.shape)
146
147
148
        pred_sa_ed = transform(x_sa_t_5D, net_sa['out'], mode='bilinear')
149
150
        # -------------- differentialable slicer
151
152
        # coordinate transformation np.dot(aff_sa_SR_inv[:3,:3], points_ED.T) + aff_sa_SR_inv[:3,3:4]
153
        v_sa_hat_t_o = torch.matmul(aff_sa_inv[:, :3, :3], pred_vertex_t) + aff_sa_inv[:, :3, 3:4]
154
        v_sa_hat_t = v_sa_hat_t_o.permute(0, 2, 1) - origin_sa
155
        # print (v_sa_hat_t.shape)
156
        v_2ch_hat_t_o = torch.matmul(aff_2ch_inv[:, :3, :3], pred_vertex_t) + aff_2ch_inv[:, :3, 3:4]
157
        v_2ch_hat_t = v_2ch_hat_t_o.permute(0, 2, 1) - origin_2ch
158
        v_4ch_hat_t_o = torch.matmul(aff_4ch_inv[:, :3, :3], pred_vertex_t) + aff_4ch_inv[:, :3, 3:4]
159
        v_4ch_hat_t = v_4ch_hat_t_o.permute(0,2, 1) - origin_4ch
160
161
        # project vertices satisfying threshood
162
        # project to SAX slices, project all vertices to a target plane,
163
        # vertices selection is moved to loss computation function
164
        v_sa_hat_t_x = torch.clamp(v_sa_hat_t[:, :, 0:1], min=0, max=height - 1)
165
        v_sa_hat_t_y = torch.clamp(v_sa_hat_t[:, :, 1:2], min=0, max=width - 1)
166
        v_sa_hat_t_cp = torch.cat((v_sa_hat_t_x, v_sa_hat_t_y, v_sa_hat_t[:, :, 2:3]), 2)
167
168
        v_sa_idx_t_0, w_sa_t_0 = projection(v_sa_hat_t_cp, 12, temper)
169
        # print (v_sa_idx_ed_0.shape, w_sa_ed_0.shape)
170
        v_sa_idx_t_1, w_sa_t_1 = projection(v_sa_hat_t_cp, 17, temper)
171
        v_sa_idx_t_2, w_sa_t_2 = projection(v_sa_hat_t_cp, 22, temper)
172
        v_sa_idx_t_3, w_sa_t_3 = projection(v_sa_hat_t_cp, 27, temper)
173
        v_sa_idx_t_4, w_sa_t_4 = projection(v_sa_hat_t_cp, 32, temper)
174
        v_sa_idx_t_5, w_sa_t_5 = projection(v_sa_hat_t_cp, 37, temper)
175
        v_sa_idx_t_6, w_sa_t_6 = projection(v_sa_hat_t_cp, 42, temper)
176
        v_sa_idx_t_7, w_sa_t_7 = projection(v_sa_hat_t_cp, 47, temper)
177
        v_sa_idx_t_8, w_sa_t_8 = projection(v_sa_hat_t_cp, 52, temper)
178
179
        # project to LAX 2CH view
180
        v_2ch_hat_t_x = torch.clamp(v_2ch_hat_t[:, :, 0:1], min=0, max=height - 1)
181
        v_2ch_hat_t_y = torch.clamp(v_2ch_hat_t[:, :, 1:2], min=0, max=width - 1)
182
        v_2ch_hat_t_cp = torch.cat((v_2ch_hat_t_x, v_2ch_hat_t_y, v_2ch_hat_t[:, :, 2:3]), 2)
183
184
        v_2ch_idx_t, w_2ch_t = projection(v_2ch_hat_t_cp, 0, temper)
185
186
187
        # project to LAX 4CH view
188
        v_4ch_hat_t_x = torch.clamp(v_4ch_hat_t[:, :, 0:1], min=0, max=height - 1)
189
        v_4ch_hat_t_y = torch.clamp(v_4ch_hat_t[:, :, 1:2], min=0, max=width - 1)
190
        v_4ch_hat_t_cp = torch.cat((v_4ch_hat_t_x, v_4ch_hat_t_y, v_4ch_hat_t[:, :, 2:3]), 2)
191
192
        v_4ch_idx_t, w_4ch_t = projection(v_4ch_hat_t_cp, 0, temper)
193
194
195
196
        # --------------------- Segmentation loss------------------
197
        loss_seg_sa_t_0 = weightedHausdorff_batch(v_sa_idx_t_0, w_sa_t_0, con_sa[:, 0, :, :], height, width, temper,
198
                                                  'train')
199
        loss_seg_sa_t_1 = weightedHausdorff_batch(v_sa_idx_t_1, w_sa_t_1, con_sa[:, 1, :, :], height, width, temper,
200
                                                  'train')
201
        loss_seg_sa_t_2 = weightedHausdorff_batch(v_sa_idx_t_2, w_sa_t_2, con_sa[:, 2, :, :], height, width, temper,
202
                                                  'train')
203
        loss_seg_sa_t_3 = weightedHausdorff_batch(v_sa_idx_t_3, w_sa_t_3, con_sa[:, 3, :, :], height, width, temper,
204
                                                  'train')
205
        loss_seg_sa_t_4 = weightedHausdorff_batch(v_sa_idx_t_4, w_sa_t_4, con_sa[:, 4, :, :], height, width, temper,
206
                                                  'train')
207
        loss_seg_sa_t_5 = weightedHausdorff_batch(v_sa_idx_t_5, w_sa_t_5, con_sa[:, 5, :, :], height, width, temper,
208
                                                  'train')
209
        loss_seg_sa_t_6 = weightedHausdorff_batch(v_sa_idx_t_6, w_sa_t_6, con_sa[:, 6, :, :], height, width, temper,
210
                                                  'train')
211
        loss_seg_sa_t_7 = weightedHausdorff_batch(v_sa_idx_t_7, w_sa_t_7, con_sa[:, 7, :, :], height, width, temper,
212
                                                  'train')
213
        loss_seg_sa_t_8 = weightedHausdorff_batch(v_sa_idx_t_8, w_sa_t_8, con_sa[:, 8, :, :], height, width, temper,
214
                                                  'train')
215
        loss_seg_2ch_t = weightedHausdorff_batch(v_2ch_idx_t, w_2ch_t, con_2ch, height, width, temper, 'train')
216
        loss_seg_4ch_t = weightedHausdorff_batch(v_4ch_idx_t, w_4ch_t, con_4ch, height, width, temper, 'train')
217
218
        loss_seg = (loss_seg_sa_t_0 + loss_seg_sa_t_1 + loss_seg_sa_t_2 + loss_seg_sa_t_3 +
219
                      loss_seg_sa_t_4 + loss_seg_sa_t_5 + loss_seg_sa_t_6 + loss_seg_sa_t_7 + loss_seg_sa_t_8) / 9.0 + \
220
                     loss_seg_2ch_t + loss_seg_4ch_t
221
222
223
224
225
        #----------------smoothness loss------------
226
        # print (pred_vertex_t.permute(0,2,1).shape)
227
        trg_mesh = Meshes(verts=list(pred_vertex_t.permute(0, 2, 1)), faces=list(faces_0))
228
        loss_smooth = loss.mesh_laplacian_smoothing(trg_mesh, method='uniform')
229
230
        # ----------------regularization loss------------
231
232
        # define image registration as a regularization term
233
        loss_reg = flow_criterion(pred_sa_ed, x_sa_ed_5D)
234
235
        loss_huber = huber_loss_3d(net_sa['out'])
236
237
238
        loss_all = loss_seg + w_reg*loss_reg + w_smooth * loss_smooth + w_h * loss_huber
239
240
        loss_all.backward()
241
        optimizer.step()
242
243
244
        epoch_loss.append(loss_all.item())
245
        epoch_seg_loss.append(loss_seg.item())
246
        epoch_smooth_loss.append(loss_smooth.item())
247
        epoch_reg_loss.append(loss_reg.item())
248
        epoch_huber_loss.append(loss_huber.item())
249
250
251
252
        # tensorboard visulisation
253
        writer.add_scalar("Loss/train", loss_all, epoch * len(training_data_loader) + batch_idx)
254
        writer.add_scalar("Loss/train_seg", loss_seg, epoch * len(training_data_loader) + batch_idx)
255
        writer.add_scalar("Loss/train_reg", loss_reg, epoch * len(training_data_loader) + batch_idx)
256
        writer.add_scalar("Loss/train_smooth", loss_smooth, epoch * len(training_data_loader) + batch_idx)
257
        writer.add_scalar("Loss/train_huber", loss_huber, epoch * len(training_data_loader) + batch_idx)
258
259
260
        if batch_idx % 40 == 0:
261
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss all: {:.6f}, '
262
                  'Seg Loss: {:.6f}, Reg Loss: {:.6f}, Smooth Loss: {:.6f}, Huber Loss: {:.6f}'.format(
263
                epoch, batch_idx * len(img_sa_t), len(training_data_loader.dataset),
264
                100. * batch_idx / len(training_data_loader), np.mean(epoch_loss),
265
                np.mean(epoch_seg_loss), np.mean(epoch_reg_loss), np.mean(epoch_smooth_loss), np.mean(epoch_huber_loss), np.mean(Myo_dice_sa), np.mean(Myo_dice_2ch), np.mean(Myo_dice_4ch)))
266
267
            # torch.save(model.state_dict(), model_save_path)
268
            # print("Checkpoint saved to {}".format(model_save_path))
269
270
def val(epoch):
271
    MotionNet.eval()
272
    MV_LA.eval()
273
274
    val_loss = []
275
    val_seg_loss = []
276
    val_smooth_loss = []
277
    val_reg_loss = []
278
    val_huber_loss = []
279
280
    global base_err
281
    for batch_idx, batch in tqdm(enumerate(val_data_loader, 1),
282
                                 total=len(val_data_loader)):
283
284
        img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, \
285
        contour_sa, contour_2ch, contour_4ch, \
286
        vertex_ed, faces, affine_inv, affine, origin = batch
287
288
        with torch.no_grad():
289
290
            x_sa_t = img_sa_t.type(Tensor)
291
            x_sa_ed = img_sa_ed.type(Tensor)
292
            x_2ch_t = img_2ch_t.type(Tensor)
293
            x_2ch_ed = img_2ch_ed.type(Tensor)
294
            x_4ch_t = img_4ch_t.type(Tensor)
295
            x_4ch_ed = img_4ch_ed.type(Tensor)
296
297
            x_sa_t_5D = img_sa_t.unsqueeze(1).type(Tensor)
298
            x_sa_ed_5D = img_sa_ed.unsqueeze(1).type(Tensor)
299
300
301
            con_sa = contour_sa.type(TensorLong)  # [bs, slices, H, W]
302
            con_2ch = contour_2ch.type(TensorLong)  # [bs, H, W]
303
            con_4ch = contour_4ch.type(TensorLong)  # [bs, H, W]
304
305
            aff_sa_inv = affine_inv[:, 0, :, :].type(Tensor)
306
            aff_sa = affine[:, 0, :, :].type(Tensor)
307
            aff_2ch_inv = affine_inv[:, 1, :, :].type(Tensor)
308
            aff_4ch_inv = affine_inv[:, 2, :, :].type(Tensor)
309
310
            origin_sa = origin[:, 0:1, :].type(Tensor)
311
            origin_2ch = origin[:, 1:2, :].type(Tensor)
312
            origin_4ch = origin[:, 2:3, :].type(Tensor)
313
314
            vertex_0 = vertex_ed.permute(0, 2, 1).type(Tensor)  # [bs, 3, number of vertices]
315
            faces_0 = faces.type(Tensor)  # [bs, number of faces, 3]
316
317
318
            net_la = MV_LA(x_2ch_t, x_2ch_ed, x_4ch_t, x_4ch_ed)
319
            net_sa = MotionNet(x_sa_t, x_sa_ed, net_la['conv2_2ch'], net_la['conv2s_2ch'], net_la['conv2_4ch'],
320
                               net_la['conv2s_4ch'])
321
322
            # ---------------sample from 3D motion fields
323
            # translate coordinate
324
            v_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], vertex_0) + aff_sa_inv[:, :3, 3:4]
325
            v_ed = v_ed_o.permute(0, 2, 1) - origin_sa  # [bs, number of vertices,3]
326
            v_ed_x = (v_ed[:, :, 0:1] - (width / 2)) / (width / 2)
327
            v_ed_y = (v_ed[:, :, 1:2] - (height / 2)) / (height / 2)
328
            v_ed_z = (v_ed[:, :, 2:3] - (depth / 2)) / (depth / 2)
329
            v_ed_norm = torch.cat((v_ed_x, v_ed_y, v_ed_z), 2)
330
            v_ed_norm_expand = v_ed_norm.unsqueeze(1).unsqueeze(1)  # [bs, 1, 1,number of vertices,3]
331
332
            # sample from 3D motion field
333
            pxx = F.grid_sample(net_sa['out'][:, 0:1], v_ed_norm_expand, align_corners=True).transpose(4, 3)
334
            pyy = F.grid_sample(net_sa['out'][:, 1:2], v_ed_norm_expand, align_corners=True).transpose(4, 3)
335
            pzz = F.grid_sample(net_sa['out'][:, 2:3], v_ed_norm_expand, align_corners=True).transpose(4, 3)
336
            delta_p = torch.cat((pxx, pyy, pzz), 4)
337
            # updata coor (image space)
338
            # print (v_ed.shape, delta_p.shape)
339
            v_t_norm_expand = v_ed_norm_expand + delta_p  # [bs, 1, 1,number of vertices,3]
340
            # t frame
341
            v_t_norm = v_t_norm_expand.squeeze(1).squeeze(1)
342
            v_t_x = v_t_norm[:, :, 0:1] * (width / 2) + (width / 2)
343
            v_t_y = v_t_norm[:, :, 1:2] * (height / 2) + (height / 2)
344
            v_t_z = v_t_norm[:, :, 2:3] * (depth / 2) + (depth / 2)
345
            v_t_crop = torch.cat((v_t_x, v_t_y, v_t_z), 2)
346
            # translate back to mesh space
347
            v_t = v_t_crop + origin_sa  # [bs, number of vertices,3]
348
            pred_vertex_t = torch.matmul(aff_sa[:, :3, :3], v_t.permute(0, 2, 1)) + aff_sa[:, :3,
349
                                                                                    3:4]  # [bs, 3, number of vertices]
350
            # print (pred_vertex_t.shape)
351
352
353
            pred_sa_ed = transform(x_sa_t_5D, net_sa['out'], mode='bilinear')
354
355
            # -------------- differentialable slicer
356
357
            # coordinate transformation np.dot(aff_sa_SR_inv[:3,:3], points_ED.T) + aff_sa_SR_inv[:3,3:4]
358
            v_sa_hat_t_o = torch.matmul(aff_sa_inv[:, :3, :3], pred_vertex_t) + aff_sa_inv[:, :3, 3:4]
359
            v_sa_hat_t = v_sa_hat_t_o.permute(0, 2, 1) - origin_sa
360
            # print (v_sa_hat_t.shape)
361
            v_2ch_hat_t_o = torch.matmul(aff_2ch_inv[:, :3, :3], pred_vertex_t) + aff_2ch_inv[:, :3, 3:4]
362
            v_2ch_hat_t = v_2ch_hat_t_o.permute(0, 2, 1) - origin_2ch
363
            v_4ch_hat_t_o = torch.matmul(aff_4ch_inv[:, :3, :3], pred_vertex_t) + aff_4ch_inv[:, :3, 3:4]
364
            v_4ch_hat_t = v_4ch_hat_t_o.permute(0, 2, 1) - origin_4ch
365
366
            # project vertices satisfying threshood
367
            # project to SAX slices, project all vertices to a target plane,
368
            # vertices selection is moved to loss computation function
369
            v_sa_hat_t_x = torch.clamp(v_sa_hat_t[:, :, 0:1], min=0, max=height - 1)
370
            v_sa_hat_t_y = torch.clamp(v_sa_hat_t[:, :, 1:2], min=0, max=width - 1)
371
            v_sa_hat_t_cp = torch.cat((v_sa_hat_t_x, v_sa_hat_t_y, v_sa_hat_t[:, :, 2:3]), 2)
372
373
            v_sa_idx_t_0, w_sa_t_0 = projection(v_sa_hat_t_cp, 12, temper)
374
            # print (v_sa_idx_ed_0.shape, w_sa_ed_0.shape)
375
            v_sa_idx_t_1, w_sa_t_1 = projection(v_sa_hat_t_cp, 17, temper)
376
            v_sa_idx_t_2, w_sa_t_2 = projection(v_sa_hat_t_cp, 22, temper)
377
            v_sa_idx_t_3, w_sa_t_3 = projection(v_sa_hat_t_cp, 27, temper)
378
            v_sa_idx_t_4, w_sa_t_4 = projection(v_sa_hat_t_cp, 32, temper)
379
            v_sa_idx_t_5, w_sa_t_5 = projection(v_sa_hat_t_cp, 37, temper)
380
            v_sa_idx_t_6, w_sa_t_6 = projection(v_sa_hat_t_cp, 42, temper)
381
            v_sa_idx_t_7, w_sa_t_7 = projection(v_sa_hat_t_cp, 47, temper)
382
            v_sa_idx_t_8, w_sa_t_8 = projection(v_sa_hat_t_cp, 52, temper)
383
384
            # project to LAX 2CH view
385
            v_2ch_hat_t_x = torch.clamp(v_2ch_hat_t[:, :, 0:1], min=0, max=height - 1)
386
            v_2ch_hat_t_y = torch.clamp(v_2ch_hat_t[:, :, 1:2], min=0, max=width - 1)
387
            v_2ch_hat_t_cp = torch.cat((v_2ch_hat_t_x, v_2ch_hat_t_y, v_2ch_hat_t[:, :, 2:3]), 2)
388
389
            v_2ch_idx_t, w_2ch_t = projection(v_2ch_hat_t_cp, 0, temper)
390
391
            # project to LAX 4CH view
392
            v_4ch_hat_t_x = torch.clamp(v_4ch_hat_t[:, :, 0:1], min=0, max=height - 1)
393
            v_4ch_hat_t_y = torch.clamp(v_4ch_hat_t[:, :, 1:2], min=0, max=width - 1)
394
            v_4ch_hat_t_cp = torch.cat((v_4ch_hat_t_x, v_4ch_hat_t_y, v_4ch_hat_t[:, :, 2:3]), 2)
395
396
            v_4ch_idx_t, w_4ch_t = projection(v_4ch_hat_t_cp, 0, temper)
397
398
            # --------------------- Segmentation loss------------------
399
            loss_seg_sa_t_0 = weightedHausdorff_batch(v_sa_idx_t_0, w_sa_t_0, con_sa[:, 0, :, :], height, width, temper,
400
                                                      'val')
401
            loss_seg_sa_t_1 = weightedHausdorff_batch(v_sa_idx_t_1, w_sa_t_1, con_sa[:, 1, :, :], height, width, temper,
402
                                                      'val')
403
            loss_seg_sa_t_2 = weightedHausdorff_batch(v_sa_idx_t_2, w_sa_t_2, con_sa[:, 2, :, :], height, width, temper,
404
                                                      'val')
405
            loss_seg_sa_t_3 = weightedHausdorff_batch(v_sa_idx_t_3, w_sa_t_3, con_sa[:, 3, :, :], height, width, temper,
406
                                                      'val')
407
            loss_seg_sa_t_4 = weightedHausdorff_batch(v_sa_idx_t_4, w_sa_t_4, con_sa[:, 4, :, :], height, width, temper,
408
                                                      'val')
409
            loss_seg_sa_t_5 = weightedHausdorff_batch(v_sa_idx_t_5, w_sa_t_5, con_sa[:, 5, :, :], height, width, temper,
410
                                                      'val')
411
            loss_seg_sa_t_6 = weightedHausdorff_batch(v_sa_idx_t_6, w_sa_t_6, con_sa[:, 6, :, :], height, width, temper,
412
                                                      'val')
413
            loss_seg_sa_t_7 = weightedHausdorff_batch(v_sa_idx_t_7, w_sa_t_7, con_sa[:, 7, :, :], height, width, temper,
414
                                                      'val')
415
            loss_seg_sa_t_8 = weightedHausdorff_batch(v_sa_idx_t_8, w_sa_t_8, con_sa[:, 8, :, :], height, width, temper,
416
                                                      'val')
417
            loss_seg_2ch_t = weightedHausdorff_batch(v_2ch_idx_t, w_2ch_t, con_2ch, height, width, temper, 'val')
418
            loss_seg_4ch_t = weightedHausdorff_batch(v_4ch_idx_t, w_4ch_t, con_4ch, height, width, temper, 'val')
419
420
            loss_seg = (loss_seg_sa_t_0 + loss_seg_sa_t_1 + loss_seg_sa_t_2 + loss_seg_sa_t_3 +
421
                        loss_seg_sa_t_4 + loss_seg_sa_t_5 + loss_seg_sa_t_6 + loss_seg_sa_t_7 + loss_seg_sa_t_8) / 9.0 + \
422
                       loss_seg_2ch_t + loss_seg_4ch_t
423
424
            # ----------------smoothness loss------------
425
            # print (pred_vertex_t.permute(0,2,1).shape)
426
            trg_mesh = Meshes(verts=list(pred_vertex_t.permute(0, 2, 1)), faces=list(faces_0))
427
            loss_smooth = loss.mesh_laplacian_smoothing(trg_mesh, method='uniform')
428
429
            # ----------------regularization loss------------
430
431
            loss_reg = flow_criterion(pred_sa_ed, x_sa_ed_5D)
432
433
            loss_huber = huber_loss_3d(net_sa['out'])
434
435
            loss_all = loss_seg + w_reg * loss_reg + w_smooth * loss_smooth + w_h * loss_huber
436
437
438
            val_loss.append(loss_all.item())
439
            val_seg_loss.append(loss_seg.item())
440
            val_smooth_loss.append(loss_smooth.item())
441
            val_reg_loss.append(loss_reg.item())
442
            val_huber_loss.append(loss_huber.item())
443
444
            if batch_idx == 1:
445
                # tensorboard visulisation
446
                writer.add_scalar("Loss/val", loss_all, epoch * len(training_data_loader) + batch_idx)
447
                writer.add_scalar("Loss/val_seg", loss_seg, epoch * len(training_data_loader) + batch_idx)
448
                writer.add_scalar("Loss/val_reg", loss_reg, epoch * len(training_data_loader) + batch_idx)
449
                writer.add_scalar("Loss/val_smooth", loss_smooth, epoch * len(training_data_loader) + batch_idx)
450
                writer.add_scalar("Loss/val_huber", loss_huber, epoch * len(training_data_loader) + batch_idx)
451
452
453
    if np.mean(val_loss) < base_err:
454
        torch.save(MotionNet.state_dict(), Motion_save_path)
455
        torch.save(MV_LA.state_dict(), Motion_LA_save_path)
456
        base_err = np.mean(val_loss)
457
458
459
460
data_path = '/train_data_path'
461
train_set = TrainDataset(data_path)
462
# loading the data
463
training_data_loader = DataLoader(dataset=train_set, num_workers=n_worker, batch_size=bs, shuffle=True)
464
465
val_data_path = '/val_data_pathl'
466
val_set = ValDataset(val_data_path)
467
val_data_loader = DataLoader(dataset=val_set, num_workers=n_worker, batch_size=bs, shuffle=False)
468
469
470
for epoch in range(0, n_epoch + 1):
471
    start = time.time()
472
    train(epoch)
473
    end = time.time()
474
    print("training took {:.8f}".format(end-start))
475
476
    print('Epoch {}'.format(epoch))
477
    start = time.time()
478
    val(epoch)
479
    end = time.time()
480
    print("validation took {:.8f}".format(end - start))