--- a/trainner.py
+++ b/trainner.py
@@ -0,0 +1,218 @@
from model.MyModel import Mymodel
from dataprocess import *
import loss.losses as losses
from metrics import *
import torch.optim as optim
import time
import numpy as np
import os
import torch
from config import Config
import shutil
from tqdm import tqdm
import imageio
import math
from bisect import bisect_right

config = Config()

torch.cuda.set_device(config.gpu)

model_name = config.arch
if not os.path.isdir('result'):
    os.mkdir('result')
if not config.resume:
    # Fresh run: truncate the log file for this script/model combination.
    with open('./result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
        f.seek(0)
        f.truncate()
model = Mymodel(img_ch=3)
model.cuda()
best_dice = 0  # best validation Dice so far
start_epoch = 0  # start from epoch 0 or the last checkpoint epoch
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
dataloader, dataloader_val = get_dataloader(config, batchsize=config.batch_size)
criterion = losses.init_loss('BCE_logit').cuda()

if config.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    if config.evaluate:
        checkpoint = torch.load('./checkpoint/' + str(model_name) + '_best.pth.tar')
    else:
        checkpoint = torch.load('./checkpoint/' + str(model_name) + '.pth.tar')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_dice = checkpoint['dice']
    # Resume mode is evaluation-only (see __main__), so training never runs.
    start_epoch = config.epochs

def adjust_lr(optimizer, epoch, eta_max=0.0001, eta_min=0.):
    cur_lr = 0.
    if config.lr_type == 'SGDR':
        # Cosine annealing with warm restarts; the cycle length doubles after
        # every restart (T_0 = config.sgdr_t, T_mult = 2).
        i = int(math.log2(epoch / config.sgdr_t + 1))
        T_cur = epoch - config.sgdr_t * (2 ** i - 1)
        T_i = config.sgdr_t * 2 ** i
        cur_lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * T_cur / T_i))
    elif config.lr_type == 'multistep':
        # Decay the base rate by 10x at each milestone epoch.
        cur_lr = config.learning_rate * 0.1 ** bisect_right(config.milestones, epoch)

    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr
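# For example, with config.sgdr_t = 10 cycle i starts at epoch 10 * (2**i - 1),
# so restarts fall at epochs 10, 30, 70, ...; within each cycle the rate
# anneals from eta_max down to eta_min along a half cosine.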

def train(epoch):
    model.train()
    train_loss = 0

    start_time = time.time()
    lr = adjust_lr(optimizer, epoch)
    for batch_idx, (inputs, lungs, medias, targets_u, targets_i, targets_s) in enumerate(dataloader):
        iter_start_time = time.time()
        inputs = inputs.cuda()
        lungs = lungs.cuda()
        medias = medias.cuda()
        targets_i = targets_i.cuda()
        targets_u = targets_u.cuda()
        targets_s = targets_s.cuda()

        outputs = model(inputs)

        # 'BCE_logit' presumably maps to a with-logits BCE loss, which applies
        # the sigmoid internally, so the raw (pre-sigmoid) outputs are passed
        # to the criterion rather than sigmoid-activated ones.
        loss_seg_i = criterion(outputs[0], targets_i)
        loss_seg_u = criterion(outputs[1], targets_u)
        loss_seg_s = criterion(outputs[2], targets_s)
        loss_seg_final = criterion(outputs[3], targets_s)

        # The final fused map is weighted separately from the three
        # intermediate maps.
        loss_all = config.weight_seg1 * loss_seg_final + config.weight_seg2 * (loss_seg_i + loss_seg_u + loss_seg_s)

        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()

        train_loss += loss_all.item()

        print('Epoch:{}\t batch_idx:{}/All_batch:{}\t duration:{:.3f}\t loss_all:{:.3f}'
              .format(epoch, batch_idx, len(dataloader), time.time() - iter_start_time, loss_all.item()))
    print('Epoch:{0}\t duration:{1:.3f}\ttrain_loss:{2:.6f}'.format(epoch, time.time() - start_time, train_loss / len(dataloader)))

    # No trailing newline here: test() appends the epoch's metrics to the same line.
    with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
        f.write('Epoch:{0}\t duration:{1:.3f}\t learning_rate:{2:.6f}\t train_loss:{3:.4f}'
                .format(epoch, time.time() - start_time, lr, train_loss / len(dataloader)))

def test(epoch):
    global best_dice
    model.eval()
    dices_all_i, dices_all_u, dices_all_s = [], [], []
    ious_all_i, ious_all_u, ious_all_s = [], [], []
    nsds_all_i, nsds_all_u, nsds_all_s = [], [], []
    with torch.no_grad():
        for batch_idx, (inputs, lungs, medias, targets_u, targets_i, targets_s) in enumerate(dataloader_val):
            inputs = inputs.cuda()
            lungs = lungs.cuda()
            medias = medias.cuda()
            targets_i = targets_i.cuda()
            targets_u = targets_u.cuda()
            targets_s = targets_s.cuda()

            outputs = model(inputs)

            # Only the final fused output is evaluated, against all three targets.
            outputs_final_sig = torch.sigmoid(outputs[3])

            dices_all_i = meandice(outputs_final_sig, targets_i, dices_all_i)
            dices_all_u = meandice(outputs_final_sig, targets_u, dices_all_u)
            dices_all_s = meandice(outputs_final_sig, targets_s, dices_all_s)
            ious_all_i = meandIoU(outputs_final_sig, targets_i, ious_all_i)
            ious_all_u = meandIoU(outputs_final_sig, targets_u, ious_all_u)
            ious_all_s = meandIoU(outputs_final_sig, targets_s, ious_all_s)
            nsds_all_i = meanNSD(outputs_final_sig, targets_i, nsds_all_i)
            nsds_all_u = meanNSD(outputs_final_sig, targets_u, nsds_all_u)
            nsds_all_s = meanNSD(outputs_final_sig, targets_s, nsds_all_s)
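            # meandice / meandIoU / meanNSD (from metrics) are assumed to append
            # the per-batch score to the list they are given and return it, so
            # each list accumulates one entry per validation batch.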

            # Save images during evaluation (currently disabled).
            # NOTE: this block still references targets_s_sig and uncertainty_map,
            # which are not defined in this version; update it before re-enabling.
            # if config.evaluate:
            #     basePath = os.path.join(config.figurePath, str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '/fold' + str(config.fold))
            #     inputsPath = os.path.join(basePath, 'inputs')
            #     masksPath_s = os.path.join(basePath, 'masks_s')
            #     outputsPath_s = os.path.join(basePath, 'outputs_s')
            #     uncertaintyPath = os.path.join(basePath, 'uncertainty_map')
            #     for path in (inputsPath, masksPath_s, outputsPath_s, uncertaintyPath):
            #         if not os.path.exists(path):
            #             os.makedirs(path)
            #     num = inputs.shape[0]
            #     inputsfolder = inputs.chunk(num, dim=0)
            #     masksfolder_s = targets_s_sig.chunk(num, dim=0)
            #     outputsfolder_s = outputs_final_sig.chunk(num, dim=0)
            #     uncertaintyfolder = uncertainty_map.chunk(num, dim=0)
            #     for index in range(num):
            #         name = str(epoch) + '_' + str(batch_idx * config.batch_size + index + 1) + '.jpg'
            #         imageio.imsave(os.path.join(inputsPath, name), inputsfolder[index].squeeze().cpu().detach().numpy())
            #         imageio.imsave(os.path.join(masksPath_s, name), masksfolder_s[index].squeeze().cpu().detach().numpy())
            #         imageio.imsave(os.path.join(outputsPath_s, name), outputsfolder_s[index].squeeze().cpu().detach().numpy())
            #         imageio.imsave(os.path.join(uncertaintyPath, name), uncertaintyfolder[index].squeeze().cpu().detach().numpy())

            print('Epoch:{}\tbatch_idx:{}/All_batch:{}\t'
                  'dice_i:{:.4f}\tdice_u:{:.4f}\tdice_s:{:.4f}\t'
                  'iou_i:{:.4f}\tiou_u:{:.4f}\tiou_s:{:.4f}\t'
                  'nsd_i:{:.4f}\tnsd_u:{:.4f}\tnsd_s:{:.4f}'
                  .format(epoch, batch_idx, len(dataloader_val),
                          np.mean(dices_all_i), np.mean(dices_all_u), np.mean(dices_all_s),
                          np.mean(ious_all_i), np.mean(ious_all_u), np.mean(ious_all_s),
                          np.mean(nsds_all_i), np.mean(nsds_all_u), np.mean(nsds_all_s)))

    with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
        f.write('\tdice_i:{:.4f}\tdice_u:{:.4f}\tdice_s:{:.4f}\t'
                'iou_i:{:.4f}\tiou_u:{:.4f}\tiou_s:{:.4f}\t'
                'nsd_i:{:.4f}\tnsd_u:{:.4f}\tnsd_s:{:.4f}\n'
                .format(np.mean(dices_all_i), np.mean(dices_all_u), np.mean(dices_all_s),
                        np.mean(ious_all_i), np.mean(ious_all_u), np.mean(ious_all_s),
                        np.mean(nsds_all_i), np.mean(nsds_all_u), np.mean(nsds_all_s)))

    # Save checkpoint.
    if not config.resume:
        dice = np.mean(dices_all_s)
        print('Validation Dice: ', dice)
        state = {
            'model': model.state_dict(),
            'dice': dice,
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/' + str(model_name) + '.pth.tar')

        # Keep a copy of the best-Dice checkpoint.
        if best_dice < dice:
            best_dice = dice
            shutil.copyfile('./checkpoint/' + str(model_name) + '.pth.tar',
                            './checkpoint/' + str(model_name) + '_best.pth.tar')
            print('Save Successfully')
    print('------------------------------------------------------------------------')

if __name__ == '__main__':
    if config.resume:
        # Resume mode only evaluates the loaded checkpoint once.
        test(start_epoch)
    else:
        for epoch in tqdm(range(start_epoch, config.epochs)):
            train(epoch)
            test(epoch)
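For reference, a minimal sketch of the Config fields this script reads. The
field names come from the usages above; every value below is an illustrative
assumption, not taken from the project's actual config.py:

class Config:
    gpu = 0                    # CUDA device index (torch.cuda.set_device)
    arch = 'MyModel'           # model name used for log and checkpoint files
    resume = False             # load a checkpoint and evaluate only
    evaluate = False           # with resume: load the '_best' checkpoint
    learning_rate = 1e-4       # Adam initial LR / base LR for 'multistep'
    batch_size = 16
    epochs = 100
    lr_type = 'SGDR'           # 'SGDR' or 'multistep'
    sgdr_t = 10                # length of the first SGDR cycle, in epochs
    milestones = [30, 60, 90]  # decay epochs for 'multistep'
    weight_seg1 = 1.0          # weight of the final fused-map loss
    weight_seg2 = 1.0          # weight of the three intermediate-map losses
    # The disabled visualization block additionally reads:
    fold = 0
    figurePath = './figures'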