Diff of /train.py [000000] .. [390c2f]

Switch to unified view

a b/train.py
1
import argparse
2
import torch
3
torch.cuda.empty_cache() # clearing the occupied cuda memory
4
from torch.backends import cudnn
5
import torch.optim as optim
6
from torch.utils.data import DataLoader
7
import os
8
import numpy as np
9
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
10
11
12
from dataset import LoadDataset
13
from model import InferenceNet, ECGnet
14
from loss import calculate_inference_loss, calculate_reconstruction_loss, calculate_ECG_reconstruction_loss, calculate_classify_loss
15
from utils import lossplot, lossplot_detailed, visualize_PC_with_label, ECG_visual_two, lossplot_classify, visualize_PC_with_twolabel
16
17
def train_ecg(args):
    """Train InferenceNet in classification mode (MI label map + ECG decoding).

    Per epoch: trains on the 'train' split, evaluates on 'val', appends each
    loss term per iteration to text files under ``args.log_dir``, and saves the
    weights with the lowest validation loss to
    ``args.log_dir/net_model_classify.pkl``.  Every 25 epochs the accumulated
    loss curves are re-plotted via ``lossplot_classify``.

    Note: the optimized objective is ``loss_seg + args.lamda_KL * KL_loss``;
    the ECG reconstruction term ``loss_signal`` is computed and logged only.

    Args:
        args: parsed CLI namespace (see the ``__main__`` argument parser).
    """
    DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    train_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='train')
    val_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='val')
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    cudnn.benchmark = True

    network = InferenceNet(in_ch=args.in_ch, out_ch=args.out_ch, num_input=args.num_input, z_dims=args.z_dims)

    if args.model is not None:
        print('Loaded trained model from {}.'.format(args.model))
        network.load_state_dict(torch.load(args.model))
    else:
        print('Begin training new model.')

    network.to(DEVICE)
    optimizer = optim.Adam(network.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_steps, gamma=args.lr_decay_rate)

    max_iter = int(len(train_dataset) / args.batch_size + 0.5)  # iterations per training epoch (rounded)
    minimum_loss = 1e4
    best_epoch = 0

    lossfile_train = args.log_dir + "/training_loss.txt"
    lossfile_val = args.log_dir + "/val_loss.txt"
    lossfile_geometry_train = args.log_dir + "/training_calculate_inference_loss.txt"
    lossfile_geometry_val = args.log_dir + "/val_calculate_inference_loss.txt"
    lossfile_KL_train = args.log_dir + "/training_KL_loss.txt"
    lossfile_KL_val = args.log_dir + "/val_KL_loss.txt"
    lossfile_ecg_train = args.log_dir + "/training_ecg_loss.txt"
    lossfile_ecg_val = args.log_dir + "/val_ecg_loss.txt"

    for epoch in range(1, args.epochs + 1):
        # Re-plot accumulated loss curves periodically.  `epoch` starts at 1,
        # so the original `epoch != 0` guard was redundant and is dropped.
        if epoch % 25 == 0:
            lossplot_classify(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_ecg_train, lossfile_ecg_val)

        # Append-mode log files.  BUGFIX: previously these were re-opened every
        # epoch and never closed, leaking one file descriptor per file per
        # epoch; they are now closed at the end of the epoch (see below).
        f_train = open(lossfile_train, 'a')
        f_val = open(lossfile_val, 'a')
        f_MI_train = open(lossfile_geometry_train, 'a')
        f_MI_val = open(lossfile_geometry_val, 'a')
        f_KL_train = open(lossfile_KL_train, 'a')
        f_KL_val = open(lossfile_KL_val, 'a')
        f_ecg_train = open(lossfile_ecg_train, 'a')
        f_ecg_val = open(lossfile_ecg_val, 'a')
        log_files = (f_train, f_val, f_MI_train, f_MI_val, f_KL_train, f_KL_val, f_ecg_train, f_ecg_val)

        # ---- training ----
        network.train()
        total_loss, iter_count = 0, 0
        for i, data in enumerate(train_dataloader, 1):
            partial_input, ECG_input, gt_MI, partial_input_coarse = data
            partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
            partial_input_coarse = partial_input_coarse.to(DEVICE)
            partial_input = partial_input.permute(0, 2, 1)  # channels-last -> channels-first for the network

            optimizer.zero_grad()

            y_MI, y_ECG, mu, log_var = network(partial_input, ECG_input)

            loss_seg, KL_loss = calculate_classify_loss(y_MI, gt_MI, mu, log_var)
            loss_signal = calculate_ECG_reconstruction_loss(y_ECG, ECG_input)
            # ECG reconstruction is tracked in the logs but intentionally
            # excluded from the optimized loss here.
            loss = loss_seg + args.lamda_KL*KL_loss

            check_grad = False  # debug switch: print loss terms and grad flags
            if check_grad:
                print(loss_seg)
                print(loss_signal)
                print(KL_loss)
                print(loss.requires_grad)
                print(loss_seg.requires_grad)
                print(KL_loss.requires_grad)
                print(loss_signal.requires_grad)

            visual_check = False  # debug switch: plot predicted vs. input ECG
            if visual_check:
                gd_ECG = ECG_input[0].cpu().detach().numpy()
                y_ECG = y_ECG[0].cpu().detach().numpy()
                ECG_visual_two(y_ECG, gd_ECG)

            loss.backward()
            optimizer.step()

            f_train.write(str(loss.item()))
            f_train.write('\n')
            f_MI_train.write(str(loss_seg.item()))
            f_MI_train.write('\n')
            f_KL_train.write(str(KL_loss.item()))
            f_KL_train.write('\n')
            f_ecg_train.write(str(loss_signal.item()))
            f_ecg_train.write('\n')

            iter_count += 1
            total_loss += loss.item()

            if i % 50 == 0:
                print("Training epoch {}/{}, iteration {}/{}: loss is {}".format(epoch, args.epochs, i, max_iter, loss.item()))
        scheduler.step()

        print("\033[96mTraining epoch {}/{}: avg loss = {}\033[0m".format(epoch, args.epochs, total_loss / iter_count))

        # ---- evaluation ----
        network.eval()
        with torch.no_grad():
            total_loss, iter_count = 0, 0
            for i, data in enumerate(val_dataloader, 1):
                partial_input, ECG_input, gt_MI, partial_input_coarse = data
                partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
                partial_input_coarse = partial_input_coarse.to(DEVICE)
                partial_input = partial_input.permute(0, 2, 1)

                y_MI, y_ECG, mu, log_var = network(partial_input, ECG_input)

                loss_seg, KL_loss = calculate_classify_loss(y_MI, gt_MI, mu, log_var)
                loss_signal = calculate_ECG_reconstruction_loss(y_ECG, ECG_input)
                loss = loss_seg + args.lamda_KL*KL_loss

                total_loss += loss.item()
                iter_count += 1

                visual_check = False  # debug switch: plot predicted vs. input ECG
                if visual_check:
                    gd_ECG = ECG_input[0].cpu().detach().numpy()
                    y_ECG = y_ECG[0].cpu().detach().numpy()
                    ECG_visual_two(y_ECG, gd_ECG)

                f_val.write(str(loss.item()))
                f_val.write('\n')
                f_MI_val.write(str(loss_seg.item()))
                f_MI_val.write('\n')
                f_KL_val.write(str(KL_loss.item()))
                f_KL_val.write('\n')
                f_ecg_val.write(str(loss_signal.item()))
                f_ecg_val.write('\n')

            mean_loss = total_loss / iter_count
            print("\033[35mValidation epoch {}/{}, loss is {}\033[0m".format(epoch, args.epochs, mean_loss))

            # Keep only the weights with the lowest validation loss so far.
            if mean_loss < minimum_loss:
                best_epoch = epoch
                minimum_loss = mean_loss
                strNetSaveName = 'net_model_classify.pkl'
                torch.save(network.state_dict(), args.log_dir + '/' + strNetSaveName)

        # Release this epoch's log-file handles (fixes the descriptor leak).
        for f in log_files:
            f.close()

        print("\033[4;37mBest model (lowest loss) in epoch {}\033[0m".format(best_epoch))

    lossplot(lossfile_train, lossfile_val)
def train(args):
    """Train InferenceNet in full inference mode: MI label map, coarse and
    detailed point-cloud reconstruction, and ECG reconstruction.

    Per epoch: trains on the 'train' split, evaluates on 'val', appends every
    loss term per iteration to text files under ``args.log_dir``, and saves
    the weights with the lowest validation loss to
    ``args.log_dir/net_model.pkl``.  Every 25 epochs the loss curves are
    re-plotted and one validation prediction is visualized.

    Note: the optimized objective combines segmentation, compactness, RV
    penalty, MI size, KL and geometry terms; the ECG reconstruction term
    ``loss_signal`` is computed and logged only (its weight is commented out).

    Args:
        args: parsed CLI namespace (see the ``__main__`` argument parser).
    """
    DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    train_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='train')
    val_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='val')
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    cudnn.benchmark = True

    network = InferenceNet(in_ch=args.in_ch, out_ch=args.out_ch, num_input=args.num_input, z_dims=args.z_dims)

    if args.model is not None:
        print('Loaded trained model from {}.'.format(args.model))
        network.load_state_dict(torch.load(args.model))
    else:
        print('Begin training new model.')

    network.to(DEVICE)
    optimizer = optim.Adam(network.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_steps, gamma=args.lr_decay_rate)

    max_iter = int(len(train_dataset) / args.batch_size + 0.5)  # iterations per training epoch (rounded)
    minimum_loss = 1e4
    best_epoch = 0

    lossfile_train = args.log_dir + "/training_loss.txt"
    lossfile_val = args.log_dir + "/val_loss.txt"
    lossfile_geometry_train = args.log_dir + "/training_calculate_inference_loss.txt"
    lossfile_geometry_val = args.log_dir + "/val_calculate_inference_loss.txt"
    lossfile_compactness_train = args.log_dir + "/training_compactness_loss.txt"
    lossfile_compactness_val = args.log_dir + "/val_compactness_loss.txt"
    lossfile_KL_train = args.log_dir + "/training_KL_loss.txt"
    lossfile_KL_val = args.log_dir + "/val_KL_loss.txt"
    lossfile_PC_train = args.log_dir + "/training_PC_loss.txt"
    lossfile_PC_val = args.log_dir + "/val_PC_loss.txt"
    lossfile_ecg_train = args.log_dir + "/training_ecg_loss.txt"
    lossfile_ecg_val = args.log_dir + "/val_ecg_loss.txt"
    lossfile_RVp_train = args.log_dir + "/training_RVp_loss.txt"
    lossfile_RVp_val = args.log_dir + "/val_RVp_loss.txt"
    lossfile_size_train = args.log_dir + "/training_MIsize_loss.txt"
    lossfile_size_val = args.log_dir + "/val_MIsize_loss.txt"

    for epoch in range(1, args.epochs + 1):
        # Re-plot accumulated loss curves periodically.  `epoch` starts at 1,
        # so the original `epoch != 0` guard was redundant and is dropped.
        if epoch % 25 == 0:
            lossplot_detailed(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val)

        # Append-mode log files.  BUGFIX: previously these were re-opened every
        # epoch and never closed, leaking one file descriptor per file per
        # epoch; they are now closed at the end of the epoch (see below).
        f_train = open(lossfile_train, 'a')
        f_val = open(lossfile_val, 'a')
        f_MI_train = open(lossfile_geometry_train, 'a')
        f_MI_val = open(lossfile_geometry_val, 'a')
        f_compactness_train = open(lossfile_compactness_train, 'a')
        f_compactness_val = open(lossfile_compactness_val, 'a')
        f_KL_train = open(lossfile_KL_train, 'a')
        f_KL_val = open(lossfile_KL_val, 'a')
        f_PC_train = open(lossfile_PC_train, 'a')
        f_PC_val = open(lossfile_PC_val, 'a')
        f_ecg_train = open(lossfile_ecg_train, 'a')
        f_ecg_val = open(lossfile_ecg_val, 'a')
        f_size_train = open(lossfile_size_train, 'a')
        f_size_val = open(lossfile_size_val, 'a')
        f_RVp_train = open(lossfile_RVp_train, 'a')
        f_RVp_val = open(lossfile_RVp_val, 'a')
        log_files = (f_train, f_val, f_MI_train, f_MI_val, f_compactness_train, f_compactness_val,
                     f_KL_train, f_KL_val, f_PC_train, f_PC_val, f_ecg_train, f_ecg_val,
                     f_size_train, f_size_val, f_RVp_train, f_RVp_val)

        # ---- training ----
        network.train()
        total_loss, iter_count = 0, 0
        for i, data in enumerate(train_dataloader, 1):
            partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = data
            partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
            partial_input_coarse = partial_input_coarse.to(DEVICE)
            partial_input = partial_input.permute(0, 2, 1)  # channels-last -> channels-first for the network

            optimizer.zero_grad()

            # Only the first 7 channels (coordinates + labels) feed the network.
            y_MI, y_coarse, y_detail, y_ECG, mu, log_var = network(partial_input[:, 0:7, :], ECG_input)

            loss_seg, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss = calculate_inference_loss(y_MI, gt_MI, mu, log_var, partial_input)
            loss_geo, loss_signal = calculate_reconstruction_loss(y_coarse, y_detail, partial_input_coarse, partial_input, y_ECG, ECG_input)
            # ECG reconstruction (loss_signal) is logged but intentionally
            # excluded from the optimized loss here.
            loss = loss_seg + args.lamda_compact*loss_compactness + args.lamda_RVp*loss_MI_RVpenalty + args.lamda_MIsize*loss_MI_size + args.lamda_KL*KL_loss + args.lamda_recon*loss_geo

            check_grad = False  # debug switch: verify every term carries gradient
            if check_grad:
                print(loss.requires_grad)
                print(loss_seg.requires_grad)
                print(loss_compactness.requires_grad)
                print(loss_MI_RVpenalty.requires_grad)
                print(KL_loss.requires_grad)
                print(loss_MI_size.requires_grad)
                print(loss_geo.requires_grad)
                print(loss_signal.requires_grad)

            visual_check = False  # debug switch: render predicted vs. ground-truth MI map
            if visual_check:
                y_predict = y_MI[0].cpu().detach().numpy()
                y_gd = gt_MI[0].cpu().detach().numpy()
                x_input = partial_input[0].cpu().detach().numpy()
                y_predict_argmax = np.argmax(y_predict, axis=0)
                visualize_PC_with_twolabel(x_input[0:3, 0:args.num_input].transpose(), y_predict_argmax, y_gd, filename='RNmap_gd_pre.jpg')

            loss.backward()
            optimizer.step()

            f_train.write(str(loss.item()))
            f_train.write('\n')
            f_MI_train.write(str(loss_seg.item()))
            f_MI_train.write('\n')
            f_compactness_train.write(str(loss_compactness.item()))
            f_compactness_train.write('\n')
            f_KL_train.write(str(KL_loss.item()))
            f_KL_train.write('\n')
            f_PC_train.write(str(loss_geo.item()))
            f_PC_train.write('\n')
            f_ecg_train.write(str(loss_signal.item()))
            f_ecg_train.write('\n')
            f_size_train.write(str((loss_MI_size.item())))
            f_size_train.write('\n')
            f_RVp_train.write(str(loss_MI_RVpenalty.item()))
            f_RVp_train.write('\n')

            iter_count += 1
            total_loss += loss.item()

            if i % 50 == 0:
                print("Training epoch {}/{}, iteration {}/{}: loss is {}".format(epoch, args.epochs, i, max_iter, loss.item()))
        scheduler.step()

        print("\033[96mTraining epoch {}/{}: avg loss = {}\033[0m".format(epoch, args.epochs, total_loss / iter_count))

        # ---- evaluation ----
        network.eval()
        with torch.no_grad():
            total_loss, iter_count = 0, 0
            for i, data in enumerate(val_dataloader, 1):
                partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = data
                partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
                partial_input_coarse = partial_input_coarse.to(DEVICE)
                partial_input = partial_input.permute(0, 2, 1)

                y_MI, y_coarse, y_detail, y_ECG, mu, log_var = network(partial_input[:, 0:7, :], ECG_input)

                loss_seg, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss = calculate_inference_loss(y_MI, gt_MI, mu, log_var, partial_input)
                loss_geo, loss_signal = calculate_reconstruction_loss(y_coarse, y_detail, partial_input_coarse, partial_input, y_ECG, ECG_input)
                loss = loss_seg + args.lamda_compact*loss_compactness + args.lamda_RVp*loss_MI_RVpenalty + args.lamda_MIsize*loss_MI_size + args.lamda_KL*KL_loss + args.lamda_recon*loss_geo

                total_loss += loss.item()
                iter_count += 1

                # Visualize the first validation sample every 25 epochs.
                if (epoch % 25 == 0) and (i == 1):
                    y_predict = y_MI[0].cpu().detach().numpy()
                    y_gd = gt_MI[0].cpu().detach().numpy()
                    x_input = partial_input[0].cpu().detach().numpy()
                    y_predict_argmax = np.argmax(y_predict, axis=0)
                    visualize_PC_with_twolabel(x_input[0:3, 0:args.num_input].transpose(), y_predict_argmax, y_gd, filename='RNmap_gd_pre.jpg')

                f_val.write(str(loss.item()))
                f_val.write('\n')
                f_MI_val.write(str(loss_seg.item()))
                f_MI_val.write('\n')
                f_compactness_val.write(str(loss_compactness.item()))
                f_compactness_val.write('\n')
                f_KL_val.write(str(KL_loss.item()))
                f_KL_val.write('\n')
                f_PC_val.write(str(loss_geo.item()))
                f_PC_val.write('\n')
                f_ecg_val.write(str(loss_signal.item()))
                f_ecg_val.write('\n')
                f_size_val.write(str(loss_MI_size.item()))
                f_size_val.write('\n')
                f_RVp_val.write(str(loss_MI_RVpenalty.item()))
                f_RVp_val.write('\n')

            mean_loss = total_loss / iter_count
            print("\033[35mValidation epoch {}/{}, loss is {}\033[0m".format(epoch, args.epochs, mean_loss))

            # Keep only the weights with the lowest validation loss so far.
            if mean_loss < minimum_loss:
                best_epoch = epoch
                minimum_loss = mean_loss
                strNetSaveName = 'net_model.pkl'
                torch.save(network.state_dict(), args.log_dir + '/' + strNetSaveName)

        # Release this epoch's log-file handles (fixes the descriptor leak).
        for f in log_files:
            f.close()

        print("\033[4;37mBest model (lowest loss) in epoch {}\033[0m".format(best_epoch))

    lossplot(lossfile_train, lossfile_val)
if __name__ == "__main__":
    # Command-line configuration; defaults reproduce the UKB clinical-data setup.
    cli = argparse.ArgumentParser()
    cli.add_argument('--partial_root', type=str, default='./Big_data_inference/meta_data/UKB_clinical_data/')
    cli.add_argument('--model', type=str, default=None)  # checkpoint to resume from, e.g. 'log/net_model.pkl'
    cli.add_argument('--in_ch', type=int, default=3 + 4)  # coordinate dimension + label index
    cli.add_argument('--out_ch', type=int, default=3)  # scar, BZ, normal (18 for ECG-based classification)
    cli.add_argument('--z_dims', type=int, default=16)  # latent-code dimensionality
    cli.add_argument('--num_input', type=int, default=4096)  # points per partial input (1024*4)
    cli.add_argument('--batch_size', type=int, default=4)
    cli.add_argument('--lamda_recon', type=float, default=1)  # geometry-reconstruction weight
    cli.add_argument('--lamda_KL', type=float, default=1e-2)  # KL-divergence weight
    cli.add_argument('--lamda_MIsize', type=float, default=1)  # MI-size penalty weight
    cli.add_argument('--lamda_RVp', type=float, default=1)  # RV penalty weight
    cli.add_argument('--lamda_compact', type=float, default=1)  # compactness weight
    cli.add_argument('--base_lr', type=float, default=1e-4)
    cli.add_argument('--lr_decay_steps', type=int, default=50)
    cli.add_argument('--lr_decay_rate', type=float, default=0.5)
    cli.add_argument('--weight_decay', type=float, default=1e-3)
    cli.add_argument('--epochs', type=int, default=500)
    cli.add_argument('--num_workers', type=int, default=1)
    cli.add_argument('--log_dir', type=str, default='log')

    train(cli.parse_args())