Diff of /train_reconstruction.py [000000] .. [721c7a]

Switch to unified view

a b/train_reconstruction.py
1
import torch.nn as nn
2
import numpy as np
3
import itertools
4
import os
5
import sys
6
7
import torch.optim as optim
8
from torch.utils.data import DataLoader
9
from torch.autograd import Variable
10
from tqdm import tqdm
11
import time
12
from torch.utils.tensorboard import SummaryWriter
13
from pytorch3d.structures import Meshes
14
from pytorch3d import loss
15
16
from network_reconstruction import *
17
from dataio_reconstruction import *
18
from utils import *
19
20
21
lr = 1e-4
22
n_worker = 4
23
bs = 5
24
n_epoch = 400
25
base_err = 10000
26
27
w_smooth = 20
28
w_surface = 0.5
29
w_h = 0.5
30
width = 128
31
height = 128
32
depth = 64
33
temper = 3
34
35
36
37
model_save_path = './models/model_reconstruction'
38
if not os.path.exists(model_save_path):
39
    os.makedirs(model_save_path)
40
41
# pytorch only saves the last model
42
Deform_save_path = os.path.join(model_save_path, 'deform.pth')
43
Motion_LA_save_path = os.path.join(model_save_path, 'multiview.pth')
44
45
DeformNet = deformnet().cuda()
46
MV_LA = Mesh_2d().cuda()
47
48
49
optimizer = optim.Adam(filter(lambda p: p.requires_grad,
50
                              itertools.chain(DeformNet.parameters(), MV_LA.parameters())), lr=lr)
51
Tensor = torch.cuda.FloatTensor
52
TensorLong = torch.cuda.LongTensor
53
54
# visualisation
55
writer = SummaryWriter('./runs/model_reconstruction')
56
57
58
59
def train(epoch):
60
    DeformNet.train()
61
    MV_LA.train()
62
63
    epoch_loss = []
64
    epoch_seg_loss = []
65
    epoch_smooth_loss = []
66
    epoch_surface_loss = []
67
    epoch_huber_loss = []
68
69
70
71
    for batch_idx, batch in tqdm(enumerate(training_data_loader, 1),
72
                                 total=len(training_data_loader)):
73
74
        img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
75
        vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = batch
76
77
78
        x_sa_ed = Variable(img_sa_ed.type(Tensor))
79
        x_2ch_t = Variable(img_2ch_t.type(Tensor))
80
        x_2ch_ed = Variable(img_2ch_ed.type(Tensor))
81
        x_4ch_t = Variable(img_4ch_t.type(Tensor))
82
        x_4ch_ed = Variable(img_4ch_ed.type(Tensor))
83
84
85
        aff_sa_inv = Variable(affine_inv[:, 0,:,:].type(Tensor))
86
        aff_sa = Variable(affine[:, 0,:,:].type(Tensor))
87
        aff_2ch_inv = Variable(affine_inv[:, 1,:,:].type(Tensor))
88
        aff_4ch_inv = Variable(affine_inv[:, 2,:,:].type(Tensor))
89
90
        origin_sa = Variable(origin[:, 0:1, :].type(Tensor))
91
        origin_2ch = Variable(origin[:, 1:2, :].type(Tensor))
92
        origin_4ch = Variable(origin[:, 2:3, :].type(Tensor))
93
94
        vertex_tpl_0 = Variable(vertex_tpl_ed.permute(0,2,1).type(Tensor)) # [bs, 3, number of vertices]
95
        faces_tpl_0 = Variable(faces_tpl.type(Tensor)) # [bs, number of faces, 3]
96
        vertex_0 = Variable(vertex_ed.permute(0, 2, 1).type(Tensor))  # [bs, 3, number of vertices]
97
98
99
        mesh2seg_sa_gt = Variable(mesh2seg_sa.type(Tensor))
100
        mesh2seg_2ch_gt = Variable(mesh2seg_2ch.type(Tensor))
101
        mesh2seg_4ch_gt = Variable(mesh2seg_4ch.type(Tensor))
102
103
104
105
        optimizer.zero_grad()
106
107
        net_la = MV_LA(x_2ch_t, x_2ch_ed, x_4ch_t, x_4ch_ed)
108
        net_df = DeformNet(x_sa_ed, net_la['conv2s_2ch'], net_la['conv2s_4ch'])
109
110
111
        # ---------------sample from 3D motion fields
112
        # translate coordinate
113
        v_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], vertex_tpl_0) + aff_sa_inv[:, :3, 3:4]
114
        v_ed = v_ed_o.permute(0, 2, 1) - origin_sa  # [bs, number of vertices,3]
115
        # normalize translated coordinate (image space) to [-1,1]
116
        v_ed_x = (v_ed[:, :, 0:1] - (width / 2)) / (width / 2)
117
        v_ed_y = (v_ed[:, :, 1:2] - (height / 2)) / (height / 2)
118
        v_ed_z = (v_ed[:, :, 2:3] - (depth / 2)) / (depth / 2)
119
        v_ed_norm = torch.cat((v_ed_x, v_ed_y, v_ed_z), 2)
120
        v_ed_norm_expand = v_ed_norm.unsqueeze(1).unsqueeze(1)  # [bs, 1, 1,number of vertices,3]
121
122
        # sample from 3D motion field
123
        pxx = F.grid_sample(net_df['out_def_ed'][:, 0:1], v_ed_norm_expand, align_corners=True).transpose(4, 3)
124
        pyy = F.grid_sample(net_df['out_def_ed'][:, 1:2], v_ed_norm_expand, align_corners=True).transpose(4, 3)
125
        pzz = F.grid_sample(net_df['out_def_ed'][:, 2:3], v_ed_norm_expand, align_corners=True).transpose(4, 3)
126
        # print (pxx.shape, pyy.shape, pzz.shape)
127
        delta_p = torch.cat((pxx, pyy, pzz), 4)
128
        # updata coor (image space)
129
        # print (v_ed.shape, delta_p.shape)
130
        v_0_norm_expand = v_ed_norm_expand + delta_p  # [bs, 1, 1,number of vertices,3]
131
        # t frame
132
        v_0_norm = v_0_norm_expand.squeeze(1).squeeze(1)
133
        v_0_x = v_0_norm[:, :, 0:1] * (width / 2) + (width / 2)
134
        v_0_y = v_0_norm[:, :, 1:2] * (height / 2) + (height / 2)
135
        v_0_z = v_0_norm[:, :, 2:3] * (depth / 2) + (depth / 2)
136
        v_0_crop = torch.cat((v_0_x, v_0_y, v_0_z), 2)
137
        # translate back to mesh space
138
        v_0 = v_0_crop + origin_sa  # [bs, number of vertices,3]
139
        pred_v_0 = torch.matmul(aff_sa[:, :3, :3], v_0.permute(0, 2, 1)) + aff_sa[:, :3,3:4]  # [bs, 3, number of vertices]
140
        # print (pred_vertex_t.shape)
141
142
143
144
        # -------------- differentialable slicer
145
146
        # coordinate transformation np.dot(aff_sa_SR_inv[:3,:3], points_ED.T) + aff_sa_SR_inv[:3,3:4]
147
        v_sa_hat_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], pred_v_0) + aff_sa_inv[:, :3, 3:4]
148
        v_sa_hat_ed = v_sa_hat_ed_o.permute(0, 2, 1) - origin_sa
149
        # print (v_sa_hat_t.shape)
150
        v_2ch_hat_ed_o = torch.matmul(aff_2ch_inv[:, :3, :3], pred_v_0) + aff_2ch_inv[:, :3, 3:4]
151
        v_2ch_hat_ed = v_2ch_hat_ed_o.permute(0, 2, 1) - origin_2ch
152
        v_4ch_hat_ed_o = torch.matmul(aff_4ch_inv[:, :3, :3], pred_v_0) + aff_4ch_inv[:, :3, 3:4]
153
        v_4ch_hat_ed = v_4ch_hat_ed_o.permute(0,2, 1) - origin_4ch
154
155
        # project vertices satisfying threshood
156
        # project to SAX slices, project all vertices to a target plane,
157
        # vertices selection is moved to loss computation function
158
        v_sa_hat_ed_x = torch.clamp(v_sa_hat_ed[:, :, 0:1], min=0, max=height - 1)
159
        v_sa_hat_ed_y = torch.clamp(v_sa_hat_ed[:, :, 1:2], min=0, max=width - 1)
160
        v_sa_hat_ed_cp = torch.cat((v_sa_hat_ed_x, v_sa_hat_ed_y, v_sa_hat_ed[:, :, 2:3]), 2)
161
162
163
        # project to LAX 2CH view
164
        v_2ch_hat_ed_x = torch.clamp(v_2ch_hat_ed[:, :, 0:1], min=0, max=height - 1)
165
        v_2ch_hat_ed_y = torch.clamp(v_2ch_hat_ed[:, :, 1:2], min=0, max=width - 1)
166
        v_2ch_hat_ed_cp = torch.cat((v_2ch_hat_ed_x, v_2ch_hat_ed_y, v_2ch_hat_ed[:, :, 2:3]), 2)
167
168
        v_2ch_idx_ed, w_2ch_ed = projection(v_2ch_hat_ed_cp, 0, temper)
169
170
171
        # project to LAX 4CH view
172
        v_4ch_hat_ed_x = torch.clamp(v_4ch_hat_ed[:, :, 0:1], min=0, max=height - 1)
173
        v_4ch_hat_ed_y = torch.clamp(v_4ch_hat_ed[:, :, 1:2], min=0, max=width - 1)
174
        v_4ch_hat_ed_cp = torch.cat((v_4ch_hat_ed_x, v_4ch_hat_ed_y, v_4ch_hat_ed[:, :, 2:3]), 2)
175
176
        v_4ch_idx_ed, w_4ch_ed = projection(v_4ch_hat_ed_cp, 0, temper)
177
178
179
180
        # --------------------- Segmentation loss------------------
181
        loss_seg_sa_ed = projection_weightHD_loss_SA(v_sa_hat_ed_cp, temper, height, width, depth, mesh2seg_sa_gt, 'train')
182
        loss_seg_2ch_ed = weightedHausdorff_batch(v_2ch_idx_ed, w_2ch_ed, mesh2seg_2ch_gt, height, width, temper, 'train')
183
        loss_seg_4ch_ed = weightedHausdorff_batch(v_4ch_idx_ed, w_4ch_ed, mesh2seg_4ch_gt, height, width, temper, 'train')
184
185
186
        loss_seg = loss_seg_sa_ed + loss_seg_2ch_ed + loss_seg_4ch_ed
187
188
189
        #----------------smoothness loss------------
190
        trg_mesh_ed = Meshes(verts=list(pred_v_0.permute(0, 2, 1)), faces=list(faces_tpl_0))
191
        loss_laplacian_smooth = loss.mesh_laplacian_smoothing(trg_mesh_ed, method='uniform')
192
193
        loss_smooth = loss_laplacian_smooth
194
195
        # ------------------J loss---------------------
196
        loss_huber = huber_loss_3d(net_df['out_def_ed'])
197
198
199
        # ------------------Surface chamfer loss---------------------
200
        loss_surface, _ = loss.chamfer_distance(pred_v_0.permute(0, 2, 1), vertex_0.permute(0, 2, 1))
201
202
203
        loss_all = loss_seg + w_surface * loss_surface + w_smooth * loss_smooth + w_h * loss_huber
204
205
        loss_all.backward()
206
        optimizer.step()
207
208
209
210
        epoch_loss.append(loss_all.item())
211
        epoch_seg_loss.append(loss_seg.item())
212
        epoch_smooth_loss.append(loss_smooth.item())
213
        epoch_surface_loss.append(loss_surface.item())
214
        epoch_huber_loss.append(loss_huber.item())
215
216
217
218
        # tensorboard visulisation
219
        writer.add_scalar("Loss/train", loss_all, epoch * len(training_data_loader) + batch_idx)
220
        writer.add_scalar("Loss/train_seg", loss_seg, epoch * len(training_data_loader) + batch_idx)
221
        writer.add_scalar("Loss/train_smooth", loss_smooth, epoch * len(training_data_loader) + batch_idx)
222
        writer.add_scalar("Loss/train_huber", loss_huber, epoch * len(training_data_loader) + batch_idx)
223
        writer.add_scalar("Loss/train_surface", loss_surface, epoch * len(training_data_loader) + batch_idx)
224
225
226
227
        if batch_idx % 40 == 0:
228
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss all: {:.6f}, '
229
                  'Seg Loss: {:.6f}, Smooth Loss: {:.6f}, Surface Loss: {:.6f}, Huger Loss: {:.6f},'.format(
230
                epoch, batch_idx * len(img_sa_t), len(training_data_loader.dataset),
231
                100. * batch_idx / len(training_data_loader), np.mean(epoch_loss),
232
                np.mean(epoch_seg_loss), np.mean(epoch_smooth_loss), np.mean(epoch_surface_loss), np.mean(epoch_huber_loss)))
233
234
235
def val(epoch):
236
    DeformNet.eval()
237
    MV_LA.eval()
238
239
    val_loss = []
240
    val_seg_loss = []
241
    val_smooth_loss = []
242
    val_surface_loss = []
243
    val_huber_loss = []
244
245
246
    global base_err
247
    for batch_idx, batch in tqdm(enumerate(val_data_loader, 1),
248
                                 total=len(val_data_loader)):
249
250
        img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
251
        vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = batch
252
253
        with torch.no_grad():
254
255
            x_sa_ed = img_sa_ed.type(Tensor)
256
            x_2ch_t = img_2ch_t.type(Tensor)
257
            x_2ch_ed = img_2ch_ed.type(Tensor)
258
            x_4ch_t = img_4ch_t.type(Tensor)
259
            x_4ch_ed = img_4ch_ed.type(Tensor)
260
261
262
            aff_sa_inv = affine_inv[:, 0, :, :].type(Tensor)
263
            aff_sa = affine[:, 0, :, :].type(Tensor)
264
            aff_2ch_inv = affine_inv[:, 1, :, :].type(Tensor)
265
            aff_4ch_inv = affine_inv[:, 2, :, :].type(Tensor)
266
267
268
            origin_sa = origin[:, 0:1, :].type(Tensor)
269
            origin_2ch = origin[:, 1:2, :].type(Tensor)
270
            origin_4ch = origin[:, 2:3, :].type(Tensor)
271
272
            vertex_tpl_0 = vertex_tpl_ed.permute(0, 2, 1).type(Tensor)  # [bs, 3, number of vertices]
273
            faces_tpl_0 = faces_tpl.type(Tensor)  # [bs, number of faces, 3]
274
            vertex_0 = vertex_ed.permute(0, 2, 1).cuda()  # [bs, 3, number of vertices]
275
276
            mesh2seg_sa_gt = Variable(mesh2seg_sa.type(Tensor))
277
            mesh2seg_2ch_gt = Variable(mesh2seg_2ch.type(Tensor))
278
            mesh2seg_4ch_gt = Variable(mesh2seg_4ch.type(Tensor))
279
280
281
            net_la = MV_LA(x_2ch_t, x_2ch_ed, x_4ch_t, x_4ch_ed)
282
            net_df = DeformNet(x_sa_ed, net_la['conv2s_2ch'], net_la['conv2s_4ch'])
283
284
            # ---------------sample from 3D motion fields
285
            # translate coordinate
286
            v_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], vertex_tpl_0) + aff_sa_inv[:, :3, 3:4]
287
            v_ed = v_ed_o.permute(0, 2, 1) - origin_sa  # [bs, number of vertices,3]
288
            # normalize translated coordinate (image space) to [-1,1]
289
            v_ed_x = (v_ed[:, :, 0:1] - (width / 2)) / (width / 2)
290
            v_ed_y = (v_ed[:, :, 1:2] - (height / 2)) / (height / 2)
291
            v_ed_z = (v_ed[:, :, 2:3] - (depth / 2)) / (depth / 2)
292
            v_ed_norm = torch.cat((v_ed_x, v_ed_y, v_ed_z), 2)
293
            v_ed_norm_expand = v_ed_norm.unsqueeze(1).unsqueeze(1)  # [bs, 1, 1,number of vertices,3]
294
295
            # sample from 3D motion field
296
            pxx = F.grid_sample(net_df['out_def_ed'][:, 0:1], v_ed_norm_expand, align_corners=True).transpose(4, 3)
297
            pyy = F.grid_sample(net_df['out_def_ed'][:, 1:2], v_ed_norm_expand, align_corners=True).transpose(4, 3)
298
            pzz = F.grid_sample(net_df['out_def_ed'][:, 2:3], v_ed_norm_expand, align_corners=True).transpose(4, 3)
299
            delta_p = torch.cat((pxx, pyy, pzz), 4)
300
            # updata coor (image space)
301
            # print (v_ed.shape, delta_p.shape)
302
            v_0_norm_expand = v_ed_norm_expand + delta_p  # [bs, 1, 1,number of vertices,3]
303
            # t frame
304
            v_0_norm = v_0_norm_expand.squeeze(1).squeeze(1)
305
            v_0_x = v_0_norm[:, :, 0:1] * (width / 2) + (width / 2)
306
            v_0_y = v_0_norm[:, :, 1:2] * (height / 2) + (height / 2)
307
            v_0_z = v_0_norm[:, :, 2:3] * (depth / 2) + (depth / 2)
308
            v_0_crop = torch.cat((v_0_x, v_0_y, v_0_z), 2)
309
            # translate back to mesh space
310
            v_0 = v_0_crop + origin_sa  # [bs, number of vertices,3]
311
            pred_v_0 = torch.matmul(aff_sa[:, :3, :3], v_0.permute(0, 2, 1)) + aff_sa[:, :3,
312
                                                                               3:4]  # [bs, 3, number of vertices]
313
314
            # -------------- differentialable slicer
315
316
            # coordinate transformation np.dot(aff_sa_SR_inv[:3,:3], points_ED.T) + aff_sa_SR_inv[:3,3:4]
317
            v_sa_hat_ed_o = torch.matmul(aff_sa_inv[:, :3, :3], pred_v_0) + aff_sa_inv[:, :3, 3:4]
318
            v_sa_hat_ed = v_sa_hat_ed_o.permute(0, 2, 1) - origin_sa
319
            # print (v_sa_hat_t.shape)
320
            v_2ch_hat_ed_o = torch.matmul(aff_2ch_inv[:, :3, :3], pred_v_0) + aff_2ch_inv[:, :3, 3:4]
321
            v_2ch_hat_ed = v_2ch_hat_ed_o.permute(0, 2, 1) - origin_2ch
322
            v_4ch_hat_ed_o = torch.matmul(aff_4ch_inv[:, :3, :3], pred_v_0) + aff_4ch_inv[:, :3, 3:4]
323
            v_4ch_hat_ed = v_4ch_hat_ed_o.permute(0, 2, 1) - origin_4ch
324
325
            # project vertices satisfying threshood
326
            # project to SAX slices, project all vertices to a target plane,
327
            # vertices selection is moved to loss computation function
328
            v_sa_hat_ed_x = torch.clamp(v_sa_hat_ed[:, :, 0:1], min=0, max=height - 1)
329
            v_sa_hat_ed_y = torch.clamp(v_sa_hat_ed[:, :, 1:2], min=0, max=width - 1)
330
            v_sa_hat_ed_cp = torch.cat((v_sa_hat_ed_x, v_sa_hat_ed_y, v_sa_hat_ed[:, :, 2:3]), 2)
331
332
333
            # project to LAX 2CH view
334
            v_2ch_hat_ed_x = torch.clamp(v_2ch_hat_ed[:, :, 0:1], min=0, max=height - 1)
335
            v_2ch_hat_ed_y = torch.clamp(v_2ch_hat_ed[:, :, 1:2], min=0, max=width - 1)
336
            v_2ch_hat_ed_cp = torch.cat((v_2ch_hat_ed_x, v_2ch_hat_ed_y, v_2ch_hat_ed[:, :, 2:3]), 2)
337
338
            v_2ch_idx_ed, w_2ch_ed = projection(v_2ch_hat_ed_cp, 0, temper)
339
340
            # project to LAX 4CH view
341
            v_4ch_hat_ed_x = torch.clamp(v_4ch_hat_ed[:, :, 0:1], min=0, max=height - 1)
342
            v_4ch_hat_ed_y = torch.clamp(v_4ch_hat_ed[:, :, 1:2], min=0, max=width - 1)
343
            v_4ch_hat_ed_cp = torch.cat((v_4ch_hat_ed_x, v_4ch_hat_ed_y, v_4ch_hat_ed[:, :, 2:3]), 2)
344
345
            v_4ch_idx_ed, w_4ch_ed = projection(v_4ch_hat_ed_cp, 0, temper)
346
347
            # --------------------- Segmentation loss------------------
348
            loss_seg_sa_ed = projection_weightHD_loss_SA(v_sa_hat_ed_cp, temper, height, width, depth, mesh2seg_sa_gt,
349
                                                         'val')
350
351
            loss_seg_2ch_ed = weightedHausdorff_batch(v_2ch_idx_ed, w_2ch_ed, mesh2seg_2ch_gt, height, width, temper,
352
                                                      'val')
353
            loss_seg_4ch_ed = weightedHausdorff_batch(v_4ch_idx_ed, w_4ch_ed, mesh2seg_4ch_gt, height, width, temper,
354
                                                      'val')
355
356
            loss_seg = loss_seg_sa_ed + loss_seg_2ch_ed + loss_seg_4ch_ed
357
358
            # ----------------smoothness loss------------
359
            # print (pred_vertex_t.permute(0,2,1).shape)
360
            trg_mesh_ed = Meshes(verts=list(pred_v_0.permute(0, 2, 1)), faces=list(faces_tpl_0))
361
            loss_laplacian_smooth = loss.mesh_laplacian_smoothing(trg_mesh_ed, method='uniform')
362
363
            loss_smooth = loss_laplacian_smooth
364
365
            # ------------------J loss---------------------
366
            loss_huber = huber_loss_3d(net_df['out_def_ed'])
367
368
369
            # ------------------Surface chamfer loss---------------------
370
            loss_surface, _ = loss.chamfer_distance(pred_v_0.permute(0, 2, 1), vertex_0.permute(0, 2, 1))
371
372
            loss_all = loss_seg + w_surface * loss_surface + w_smooth * loss_smooth + w_h * loss_huber
373
374
375
            val_loss.append(loss_all.item())
376
            val_seg_loss.append(loss_seg.item())
377
            val_smooth_loss.append(loss_smooth.item())
378
            val_surface_loss.append(loss_surface.item())
379
            val_huber_loss.append(loss_huber.item())
380
381
            if batch_idx == 1:
382
                # tensorboard visulisation
383
                writer.add_scalar("Loss/val", loss_all, epoch * len(training_data_loader) + batch_idx)
384
                writer.add_scalar("Loss/val_seg", loss_seg, epoch * len(training_data_loader) + batch_idx)
385
                writer.add_scalar("Loss/val_smooth", loss_smooth, epoch * len(training_data_loader) + batch_idx)
386
                writer.add_scalar("Loss/val_huber", loss_huber, epoch * len(training_data_loader) + batch_idx)
387
                writer.add_scalar("Loss/val_surface", loss_surface, epoch * len(training_data_loader) + batch_idx)
388
389
390
    if np.mean(val_loss) < base_err:
391
        torch.save(DeformNet.state_dict(), Deform_save_path)
392
        torch.save(MV_LA.state_dict(), Motion_LA_save_path)
393
        base_err = np.mean(val_loss)
394
395
396
397
data_path = '/train_data_path'
398
train_set = TrainDataset(data_path)
399
# loading the data
400
training_data_loader = DataLoader(dataset=train_set, num_workers=n_worker, batch_size=bs, shuffle=True)
401
402
val_data_path = '/val_data_path'
403
val_set = ValDataset(val_data_path)
404
val_data_loader = DataLoader(dataset=val_set, num_workers=n_worker, batch_size=bs, shuffle=False)
405
406
407
for epoch in range(0, n_epoch + 1):
408
    start = time.time()
409
    train(epoch)
410
    end = time.time()
411
    print("training took {:.8f}".format(end-start))
412
413
    print('Epoch {}'.format(epoch))
414
    start = time.time()
415
    val(epoch)
416
    end = time.time()
417
    print("validation took {:.8f}".format(end - start))