--- /dev/null
+++ b/train.py
@@ -0,0 +1,394 @@
+import argparse
+import os
+# The allocator config must be set before CUDA is initialised, so export it
+# before importing torch (it was previously set after torch.cuda was touched).
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
+
+import numpy as np
+import torch
+from torch.backends import cudnn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()  # release any CUDA memory cached by earlier runs
+
+from dataset import LoadDataset
+from model import InferenceNet, ECGnet
+from loss import calculate_inference_loss, calculate_reconstruction_loss, calculate_ECG_reconstruction_loss, calculate_classify_loss
+from utils import lossplot, lossplot_detailed, visualize_PC_with_label, ECG_visual_two, lossplot_classify, visualize_PC_with_twolabel
+
+
+def train_ecg(args):
+    DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    # DEVICE = torch.device('cpu')
+    train_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='train')
+    val_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='val')
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
+    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
+    cudnn.benchmark = True
+
+    network = InferenceNet(in_ch=args.in_ch, out_ch=args.out_ch, num_input=args.num_input, z_dims=args.z_dims)
+
+    if args.model is not None:
+        print('Loaded trained model from {}.'.format(args.model))
+        network.load_state_dict(torch.load(args.model, map_location=DEVICE))
+    else:
+        print('Begin training new model.')
+
+    network.to(DEVICE)
+    optimizer = optim.Adam(network.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_steps, gamma=args.lr_decay_rate)
+
+    max_iter = int(len(train_dataset) / args.batch_size + 0.5)  # iterations per epoch
+    minimum_loss = 1e4
+    best_epoch = 0
+
+    lossfile_train = args.log_dir + "/training_loss.txt"
+    lossfile_val = args.log_dir + "/val_loss.txt"
+    lossfile_geometry_train = args.log_dir + "/training_calculate_inference_loss.txt"
+    lossfile_geometry_val = args.log_dir + "/val_calculate_inference_loss.txt"
+    lossfile_KL_train = args.log_dir + "/training_KL_loss.txt"
+    lossfile_KL_val = args.log_dir + "/val_KL_loss.txt"
+    lossfile_ecg_train = args.log_dir + "/training_ecg_loss.txt"
+    lossfile_ecg_val = args.log_dir + "/val_ecg_loss.txt"
+
+    for epoch in range(1, args.epochs + 1):
+        if (epoch % 25) == 0:  # epoch starts at 1, so no extra zero check is needed
+            lossplot_classify(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_ecg_train, lossfile_ecg_val)
+
+        # 'a' appends to the existing loss logs; 'w' would overwrite them
+        f_train = open(lossfile_train, 'a')
+        f_val = open(lossfile_val, 'a')
+        f_MI_train = open(lossfile_geometry_train, 'a')
+        f_MI_val = open(lossfile_geometry_val, 'a')
+        f_KL_train = open(lossfile_KL_train, 'a')
+        f_KL_val = open(lossfile_KL_val, 'a')
+        f_ecg_train = open(lossfile_ecg_train, 'a')
+        f_ecg_val = open(lossfile_ecg_val, 'a')
+
+        # if ((epoch % 25) == 0) and (epoch != 0):
+        #     if lamda_KL < 1:
+        #         lamda_KL = 0.1*epoch*lamda_KL  # 0.25
+        #     else:
+        #         lamda_KL = 0.1
+
+        # training
+        network.train()
+        total_loss, iter_count = 0, 0
+        for i, data in enumerate(train_dataloader, 1):
+            partial_input, ECG_input, gt_MI, partial_input_coarse = data
+            partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
+            partial_input_coarse = partial_input_coarse.to(DEVICE)
+            partial_input = partial_input.permute(0, 2, 1)
+
+            optimizer.zero_grad()
+
+            y_MI, y_ECG, mu, log_var = network(partial_input, ECG_input)
+
+            loss_seg, KL_loss = calculate_classify_loss(y_MI, gt_MI, mu, log_var)
+            loss_signal = calculate_ECG_reconstruction_loss(y_ECG, ECG_input)  # logged only; not part of the objective
+            loss = loss_seg + args.lamda_KL*KL_loss
+
+            check_grad = False
+            if check_grad:
+                print(loss_seg)
+                print(loss_signal)
+                print(KL_loss)
+
+                print(loss.requires_grad)
+                print(loss_seg.requires_grad)
+                print(KL_loss.requires_grad)
+                print(loss_signal.requires_grad)
+
+            visual_check = False
+            if visual_check:
+                gd_ECG = ECG_input[0].cpu().detach().numpy()
+                y_ECG = y_ECG[0].cpu().detach().numpy()
+                ECG_visual_two(y_ECG, gd_ECG)
+
+            loss.backward()
+            optimizer.step()
+
+            f_train.write(str(loss.item()))
+            f_train.write('\n')
+            f_MI_train.write(str(loss_seg.item()))
+            f_MI_train.write('\n')
+            f_KL_train.write(str(KL_loss.item()))
+            f_KL_train.write('\n')
+            f_ecg_train.write(str(loss_signal.item()))
+            f_ecg_train.write('\n')
+
+            iter_count += 1
+            total_loss += loss.item()
+
+            if i % 50 == 0:
+                print("Training epoch {}/{}, iteration {}/{}: loss is {}".format(epoch, args.epochs, i, max_iter, loss.item()))
+        scheduler.step()  # StepLR counts epochs (lr_decay_steps), so step once per epoch
+
+        print("\033[96mTraining epoch {}/{}: avg loss = {}\033[0m".format(epoch, args.epochs, total_loss / iter_count))
+
+        # evaluation
+        network.eval()
+        with torch.no_grad():
+            total_loss, iter_count = 0, 0
+            for i, data in enumerate(val_dataloader, 1):
+                partial_input, ECG_input, gt_MI, partial_input_coarse = data
+                partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
+                partial_input_coarse = partial_input_coarse.to(DEVICE)
+                partial_input = partial_input.permute(0, 2, 1)
+
+                y_MI, y_ECG, mu, log_var = network(partial_input, ECG_input)
+
+                loss_seg, KL_loss = calculate_classify_loss(y_MI, gt_MI, mu, log_var)
+                loss_signal = calculate_ECG_reconstruction_loss(y_ECG, ECG_input)
+                loss = loss_seg + args.lamda_KL*KL_loss
+
+                total_loss += loss.item()
+                iter_count += 1
+
+                visual_check = False
+                if visual_check:
+                    gd_ECG = ECG_input[0].cpu().detach().numpy()
+                    y_ECG = y_ECG[0].cpu().detach().numpy()
+                    ECG_visual_two(y_ECG, gd_ECG)
+
+                f_val.write(str(loss.item()))
+                f_val.write('\n')
+                f_MI_val.write(str(loss_seg.item()))
+                f_MI_val.write('\n')
+                f_KL_val.write(str(KL_loss.item()))
+                f_KL_val.write('\n')
+                f_ecg_val.write(str(loss_signal.item()))
+                f_ecg_val.write('\n')
+
+        # close per-epoch log files so buffered loss values are flushed to disk
+        for f in (f_train, f_val, f_MI_train, f_MI_val, f_KL_train, f_KL_val, f_ecg_train, f_ecg_val):
+            f.close()
+
+        mean_loss = total_loss / iter_count
+        print("\033[35mValidation epoch {}/{}, loss is {}\033[0m".format(epoch, args.epochs, mean_loss))
+
+        # record the best model and epoch
+        if mean_loss < minimum_loss:
+            best_epoch = epoch
+            minimum_loss = mean_loss
+            strNetSaveName = 'net_model_classify.pkl'
+            # strNetSaveName = 'net_with_%d.pkl' % epoch
+            torch.save(network.state_dict(), args.log_dir + '/' + strNetSaveName)
+
+        print("\033[4;37mBest model (lowest loss) in epoch {}\033[0m".format(best_epoch))
+
+    lossplot(lossfile_train, lossfile_val)
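+
+
+# NOTE: the commented-out blocks in train_ecg() and train() sketch a KL warm-up
+# schedule that scales the weight with the epoch index and clamps it at 0.1.
+# A minimal self-contained version of that idea could look like the helper
+# below; the function name and the `cap` parameter are hypothetical
+# illustrations, not part of the original training recipe.
+def kl_warmup_weight(epoch, base_weight, cap=0.1):
+    """Ramp the KL weight as 0.1 * epoch * base_weight, clamped at `cap`."""
+    return min(0.1 * epoch * base_weight, cap)
+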
+
+def train(args):
+    DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    # DEVICE = torch.device('cpu')
+    train_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='train')
+    val_dataset = LoadDataset(path=args.partial_root, num_input=args.num_input, split='val')
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
+    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
+    cudnn.benchmark = True
+
+    network = InferenceNet(in_ch=args.in_ch, out_ch=args.out_ch, num_input=args.num_input, z_dims=args.z_dims)
+
+    if args.model is not None:
+        print('Loaded trained model from {}.'.format(args.model))
+        network.load_state_dict(torch.load(args.model, map_location=DEVICE))
+    else:
+        print('Begin training new model.')
+
+    network.to(DEVICE)
+    optimizer = optim.Adam(network.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_steps, gamma=args.lr_decay_rate)
+
+    max_iter = int(len(train_dataset) / args.batch_size + 0.5)  # iterations per epoch
+    minimum_loss = 1e4
+    best_epoch = 0
+
+    lossfile_train = args.log_dir + "/training_loss.txt"
+    lossfile_val = args.log_dir + "/val_loss.txt"
+    lossfile_geometry_train = args.log_dir + "/training_calculate_inference_loss.txt"
+    lossfile_geometry_val = args.log_dir + "/val_calculate_inference_loss.txt"
+    lossfile_compactness_train = args.log_dir + "/training_compactness_loss.txt"
+    lossfile_compactness_val = args.log_dir + "/val_compactness_loss.txt"
+    lossfile_KL_train = args.log_dir + "/training_KL_loss.txt"
+    lossfile_KL_val = args.log_dir + "/val_KL_loss.txt"
+    lossfile_PC_train = args.log_dir + "/training_PC_loss.txt"
+    lossfile_PC_val = args.log_dir + "/val_PC_loss.txt"
+    lossfile_ecg_train = args.log_dir + "/training_ecg_loss.txt"
+    lossfile_ecg_val = args.log_dir + "/val_ecg_loss.txt"
+    lossfile_RVp_train = args.log_dir + "/training_RVp_loss.txt"
+    lossfile_RVp_val = args.log_dir + "/val_RVp_loss.txt"
+    lossfile_size_train = args.log_dir + "/training_MIsize_loss.txt"
+    lossfile_size_val = args.log_dir + "/val_MIsize_loss.txt"
+
+    lamda_KL = args.lamda_KL  # fixed KL weight; the commented-out schedule below would anneal it
+    for epoch in range(1, args.epochs + 1):
+        if (epoch % 25) == 0:  # epoch starts at 1, so no extra zero check is needed
+            lossplot_detailed(lossfile_train, lossfile_val, lossfile_geometry_train, lossfile_geometry_val, lossfile_KL_train, lossfile_KL_val, lossfile_compactness_train, lossfile_compactness_val, lossfile_PC_train, lossfile_PC_val, lossfile_ecg_train, lossfile_ecg_val, lossfile_RVp_train, lossfile_RVp_val, lossfile_size_train, lossfile_size_val)
+
+        # 'a' appends to the existing loss logs; 'w' would overwrite them
+        f_train = open(lossfile_train, 'a')
+        f_val = open(lossfile_val, 'a')
+        f_MI_train = open(lossfile_geometry_train, 'a')
+        f_MI_val = open(lossfile_geometry_val, 'a')
+        f_compactness_train = open(lossfile_compactness_train, 'a')
+        f_compactness_val = open(lossfile_compactness_val, 'a')
+        f_KL_train = open(lossfile_KL_train, 'a')
+        f_KL_val = open(lossfile_KL_val, 'a')
+        f_PC_train = open(lossfile_PC_train, 'a')
+        f_PC_val = open(lossfile_PC_val, 'a')
+        f_ecg_train = open(lossfile_ecg_train, 'a')
+        f_ecg_val = open(lossfile_ecg_val, 'a')
+        f_size_train = open(lossfile_size_train, 'a')
+        f_size_val = open(lossfile_size_val, 'a')
+        f_RVp_train = open(lossfile_RVp_train, 'a')
+        f_RVp_val = open(lossfile_RVp_val, 'a')
+
+        # if epoch != 0:
+        #     if lamda_KL < 1:
+        #         lamda_KL = 0.1*epoch*args.lamda_KL
+        #     else:
+        #         lamda_KL = 0.1
+
+        # training
+        network.train()
+        total_loss, iter_count = 0, 0
+        for i, data in enumerate(train_dataloader, 1):
+            partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = data
+            partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
+            partial_input_coarse = partial_input_coarse.to(DEVICE)
+            partial_input = partial_input.permute(0, 2, 1)
+
+            optimizer.zero_grad()
+
+            # only the first 7 channels (3 coordinates + 4 label channels, cf. --in_ch) feed the network
+            y_MI, y_coarse, y_detail, y_ECG, mu, log_var = network(partial_input[:, 0:7, :], ECG_input)
+
+            loss_seg, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss = calculate_inference_loss(y_MI, gt_MI, mu, log_var, partial_input)
+            loss_geo, loss_signal = calculate_reconstruction_loss(y_coarse, y_detail, partial_input_coarse, partial_input, y_ECG, ECG_input)
+            loss = loss_seg + args.lamda_compact*loss_compactness + args.lamda_RVp*loss_MI_RVpenalty + args.lamda_MIsize*loss_MI_size + lamda_KL*KL_loss + args.lamda_recon*loss_geo  # + args.lamda_recon*loss_signal
+
+            check_grad = False
+            if check_grad:
+                print(loss.requires_grad)
+                print(loss_seg.requires_grad)
+                print(loss_compactness.requires_grad)
+                print(loss_MI_RVpenalty.requires_grad)
+                print(KL_loss.requires_grad)
+                print(loss_MI_size.requires_grad)
+                print(loss_geo.requires_grad)
+                print(loss_signal.requires_grad)
+
+            visual_check = False
+            if visual_check:
+                y_predict = y_MI[0].cpu().detach().numpy()
+                y_gd = gt_MI[0].cpu().detach().numpy()
+                x_input = partial_input[0].cpu().detach().numpy()
+                y_predict_argmax = np.argmax(y_predict, axis=0)
+                visualize_PC_with_twolabel(x_input[0:3, 0:args.num_input].transpose(), y_predict_argmax, y_gd, filename='RNmap_gd_pre.jpg')
+
+            loss.backward()
+            optimizer.step()
+
+            f_train.write(str(loss.item()))
+            f_train.write('\n')
+            f_MI_train.write(str(loss_seg.item()))
+            f_MI_train.write('\n')
+            f_compactness_train.write(str(loss_compactness.item()))
+            f_compactness_train.write('\n')
+            f_KL_train.write(str(KL_loss.item()))
+            f_KL_train.write('\n')
+            f_PC_train.write(str(loss_geo.item()))
+            f_PC_train.write('\n')
+            f_ecg_train.write(str(loss_signal.item()))
+            f_ecg_train.write('\n')
+            f_size_train.write(str(loss_MI_size.item()))
+            f_size_train.write('\n')
+            f_RVp_train.write(str(loss_MI_RVpenalty.item()))
+            f_RVp_train.write('\n')
+
+            iter_count += 1
+            total_loss += loss.item()
+
+            if i % 50 == 0:
+                print("Training epoch {}/{}, iteration {}/{}: loss is {}".format(epoch, args.epochs, i, max_iter, loss.item()))
+        scheduler.step()  # StepLR counts epochs (lr_decay_steps), so step once per epoch
+
+        print("\033[96mTraining epoch {}/{}: avg loss = {}\033[0m".format(epoch, args.epochs, total_loss / iter_count))
+
+        # evaluation
+        network.eval()
+        with torch.no_grad():
+            total_loss, iter_count = 0, 0
+            for i, data in enumerate(val_dataloader, 1):
+                partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = data
+                partial_input, ECG_input, gt_MI = partial_input.to(DEVICE), ECG_input.to(DEVICE), gt_MI.to(DEVICE)
+                partial_input_coarse = partial_input_coarse.to(DEVICE)
+                partial_input = partial_input.permute(0, 2, 1)
+
+                y_MI, y_coarse, y_detail, y_ECG, mu, log_var = network(partial_input[:, 0:7, :], ECG_input)
+
+                loss_seg, loss_compactness, loss_MI_RVpenalty, loss_MI_size, KL_loss = calculate_inference_loss(y_MI, gt_MI, mu, log_var, partial_input)
+                loss_geo, loss_signal = calculate_reconstruction_loss(y_coarse, y_detail, partial_input_coarse, partial_input, y_ECG, ECG_input)
+                loss = loss_seg + args.lamda_compact*loss_compactness + args.lamda_RVp*loss_MI_RVpenalty + args.lamda_MIsize*loss_MI_size + lamda_KL*KL_loss + args.lamda_recon*loss_geo  # + args.lamda_recon*loss_signal
+
+                total_loss += loss.item()
+                iter_count += 1
+
+                if ((epoch % 25) == 0) and (i == 1):
+                    y_predict = y_MI[0].cpu().detach().numpy()
+                    y_gd = gt_MI[0].cpu().detach().numpy()
+                    x_input = partial_input[0].cpu().detach().numpy()
+                    y_predict_argmax = np.argmax(y_predict, axis=0)
+                    visualize_PC_with_twolabel(x_input[0:3, 0:args.num_input].transpose(), y_predict_argmax, y_gd, filename='RNmap_gd_pre.jpg')
+
+                f_val.write(str(loss.item()))
+                f_val.write('\n')
+                f_MI_val.write(str(loss_seg.item()))
+                f_MI_val.write('\n')
+                f_compactness_val.write(str(loss_compactness.item()))
+                f_compactness_val.write('\n')
+                f_KL_val.write(str(KL_loss.item()))
+                f_KL_val.write('\n')
+                f_PC_val.write(str(loss_geo.item()))
+                f_PC_val.write('\n')
+                f_ecg_val.write(str(loss_signal.item()))
+                f_ecg_val.write('\n')
+                f_size_val.write(str(loss_MI_size.item()))
+                f_size_val.write('\n')
+                f_RVp_val.write(str(loss_MI_RVpenalty.item()))
+                f_RVp_val.write('\n')
+
+        # close per-epoch log files so buffered loss values are flushed to disk
+        for f in (f_train, f_val, f_MI_train, f_MI_val, f_compactness_train, f_compactness_val,
+                  f_KL_train, f_KL_val, f_PC_train, f_PC_val, f_ecg_train, f_ecg_val,
+                  f_size_train, f_size_val, f_RVp_train, f_RVp_val):
+            f.close()
+
+        mean_loss = total_loss / iter_count
+        print("\033[35mValidation epoch {}/{}, loss is {}\033[0m".format(epoch, args.epochs, mean_loss))
+
+        # record the best model and epoch
+        if mean_loss < minimum_loss:
+            best_epoch = epoch
+            minimum_loss = mean_loss
+            strNetSaveName = 'net_model.pkl'
+            # strNetSaveName = 'net_with_%d.pkl' % epoch
+            torch.save(network.state_dict(), args.log_dir + '/' + strNetSaveName)
+
+        print("\033[4;37mBest model (lowest loss) in epoch {}\033[0m".format(best_epoch))
+
+    lossplot(lossfile_train, lossfile_val)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--partial_root', type=str, default='./Big_data_inference/meta_data/UKB_clinical_data/')
+    parser.add_argument('--model', type=str, default=None)  # e.g. 'log/net_model.pkl' to resume
+    parser.add_argument('--in_ch', type=int, default=3+4)  # 3 coordinate dims + 4 label channels
+    parser.add_argument('--out_ch', type=int, default=3)  # 3 classes: scar, border zone (BZ), normal; 18 for ECG-based classification
+    parser.add_argument('--z_dims', type=int, default=16)
+    parser.add_argument('--num_input', type=int, default=1024*4)
+    parser.add_argument('--batch_size', type=int, default=4)
+    parser.add_argument('--lamda_recon', type=float, default=1)
+    parser.add_argument('--lamda_KL', type=float, default=1e-2)
+    parser.add_argument('--lamda_MIsize', type=float, default=1)
+    parser.add_argument('--lamda_RVp', type=float, default=1)
+    parser.add_argument('--lamda_compact', type=float, default=1)
+    parser.add_argument('--base_lr', type=float, default=1e-4)
+    parser.add_argument('--lr_decay_steps', type=int, default=50)
+    parser.add_argument('--lr_decay_rate', type=float, default=0.5)
+    parser.add_argument('--weight_decay', type=float, default=1e-3)
+    parser.add_argument('--epochs', type=int, default=500)
+    parser.add_argument('--num_workers', type=int, default=1)
+    parser.add_argument('--log_dir', type=str, default='log')
+    args = parser.parse_args()
+
+    os.makedirs(args.log_dir, exist_ok=True)  # loss logs and checkpoints are written here
+    train(args)
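+
+# Example invocation (flags and defaults are those in the argparse block above;
+# adjust the data path to the local layout):
+#   python train.py --partial_root ./Big_data_inference/meta_data/UKB_clinical_data/ \
+#       --batch_size 4 --epochs 500 --log_dir log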