--- /dev/null
+++ b/train_infer.py
@@ -0,0 +1,252 @@
+"""Training and inference for the augmented model."""
+import os
+import time
+
+import numpy as np
+import SimpleITK as sitk
+import torch
+import torch.nn as nn
+from apex import amp
+from ptflops import get_model_complexity_info
+from tensorboardX import SummaryWriter
+
+import data_generator_3D
+import utils
+from config import TrainConfig
+from model import LCOVNet
+
+config = TrainConfig()
+
+device = torch.device("cuda")
+
+# tensorboard writer for step-level scalars
+writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
+writer.add_text('config', config.as_markdown(), 0)
+
+logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
+config.print_params(logger.info)
+
+
+def main():
+    logger.info("Logger is set - training start")
+
+    # set default gpu device id
+    torch.cuda.set_device(config.gpus[0])
+
+    # seed everything for reproducibility
+    np.random.seed(config.seed)
+    torch.manual_seed(config.seed)
+    torch.cuda.manual_seed_all(config.seed)
+
+    torch.backends.cudnn.benchmark = True
+
+    criterion = utils.log_loss().to(device)
+    model = LCOVNet(config.input_channels, config.n_classes).to(device)
+
+    # report MACs and parameter count for a single [C=1, H=240, W=160, D=48] input
+    with torch.cuda.device(config.gpus[0]):
+        macs, params = get_model_complexity_info(model, (1, 240, 160, 48), as_strings=True,
+                                                 print_per_layer_stat=True, verbose=True)
+    logger.info("{:<30} {:<8}".format('Computational complexity: ', macs))
+    logger.info("{:<30} {:<8}".format('Number of parameters: ', params))
+
+    # model size
+    mb_params = utils.param_size(model)
+    logger.info("Model size = {:.3f} MB".format(mb_params))
+
+    # weights optimizer, wrapped by apex amp for mixed-precision training
+    optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum,
+                                weight_decay=config.weight_decay)
+    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
+
+    train_loader = data_generator_3D.Covid19TrainSet()
+    valid_loader = data_generator_3D.Covid19EvalSet()
+    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
+
+    best_dice = 0.
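+    # Data layout note (inferred from the ptflops input_res above and the
+    # depth-axis slicing in validate() below): batches are [N, C=1, H, W, D]
+    # tensors, and validation volumes are processed in depth windows of 48.
+    # Best checkpoints are written in two places, matching the original
+    # behaviour: by utils.save_checkpoint() below, and by validate() itself.
+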
+    # training loop
+    summ_writer = SummaryWriter(config.training_summary_dir)
+    for epoch in range(config.epochs):
+        # training
+        train(train_loader, model, optimizer, criterion, epoch, summ_writer)
+        lr_scheduler.step()
+
+        # validation
+        mean_dice = validate(valid_loader, model, criterion, epoch, summ_writer, best_dice)
+
+        # save
+        is_best = mean_dice > best_dice
+        if is_best:
+            best_dice = mean_dice
+        utils.save_checkpoint(model, config.path, is_best)
+        print("")
+
+    logger.info("Final best Dice = {:.4%}".format(best_dice))
+    utils.save_results(best_dice, config.path)
+    summ_writer.close()
+
+
+def train(train_loader, model, optimizer, criterion, epoch, summ_writer):
+    losses = utils.AverageMeter()
+    cur_step = epoch * len(train_loader)
+    cur_lr = optimizer.param_groups[0]['lr']
+    logger.info("Epoch {} LR {}".format(epoch, cur_lr))
+    writer.add_scalar('train/lr', cur_lr, cur_step)
+    model.train()
+
+    all_dice = []
+    for step, (name, X, y) in enumerate(train_loader):
+        X = torch.from_numpy(X).to(device, non_blocking=True)
+        y = torch.from_numpy(y).to(device, non_blocking=True)
+        N = X.size(0)
+
+        optimizer.zero_grad()
+        logits = model(X)
+        loss = criterion(logits, y)
+        # scale the loss for mixed-precision training instead of a plain
+        # loss.backward()
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        # gradient clipping; with apex amp it is the master (fp32) gradients
+        # that must be clipped
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.grad_clip)
+        optimizer.step()
+
+        losses.update(loss.item(), N)
+
+        if step % config.print_freq == 0 or step == len(train_loader) - 1:
+            logger.info(
+                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {:.3f}".format(
+                    epoch + 1, config.epochs, step, len(train_loader), losses.avg))
+
+        writer.add_scalar('train/loss', loss.item(), cur_step)
+
+        # binarize predictions; the network output is assumed to be a
+        # probability map already (activation applied inside the model)
+        logits[logits >= 0.5] = 1
+        logits[logits < 0.5] = 0
+        predict = logits.cpu().detach().numpy()
+        y = y.cpu().detach().numpy()
+        all_dice.append(utils.evaluate(predict, y))
+        cur_step += 1
+
+    train_avg_loss = losses.avg
+    train_avg_dice = float(np.mean(all_dice))
+    summ_writer.add_scalars('loss', {'train': train_avg_loss}, epoch + 1)
+    summ_writer.add_scalars('avg_dice', {'train': train_avg_dice}, epoch + 1)
+
+    # periodic snapshot with optimizer and amp state so training can resume
+    if (epoch + 1) % 50 == 0:
+        chpt_prefx = config.training_checkpoint_prefix
+        save_dict = {'epoch': epoch + 1,
+                     'model_state_dict': model.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'amp': amp.state_dict()}
+        save_name = "{0:}_{1:}.pt".format(chpt_prefx, epoch + 1)
+        torch.save(save_dict, save_name)
+    logger.info("train_avg_loss: {:.4f}".format(train_avg_loss))
+    logger.info("train_avg_dice: {:.4%}".format(train_avg_dice))
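+
+# utils.evaluate (used in train() above and validate() below) is defined
+# elsewhere in this repo. As a readers' reference only, a minimal sketch of a
+# binary Dice metric of the usual form; the real utils.evaluate may differ
+# (e.g. averaging over classes or cases), and this helper is never called here.
+def _reference_binary_dice(predict, label, smooth=1e-5):
+    """Dice = (2 * |P ∩ G| + s) / (|P| + |G| + s) for binarized arrays."""
+    p = np.asarray(predict, dtype=np.float32).reshape(-1)
+    g = np.asarray(label, dtype=np.float32).reshape(-1)
+    inter = float((p * g).sum())
+    return (2.0 * inter + smooth) / (float(p.sum()) + float(g.sum()) + smooth)
+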
+def validate(valid_loader, model, criterion, epoch, summ_writer, best_dice):
+    # criterion is accepted to mirror train()'s signature; validation itself
+    # is Dice-only (the original logged an unused, always-empty loss meter)
+    model.eval()
+    all_dice = []
+    total_time = 0.
+    size_z = 48  # depth of each sliding window; volumes are assumed >= 48 deep
+    with torch.no_grad():
+        for i, (name, image, label) in enumerate(valid_loader):
+            image = torch.from_numpy(image)
+            predict = np.zeros(shape=label.shape, dtype=label.dtype)
+            z = image.shape[4]
+            m = (z + size_z - 1) // size_z  # number of depth windows (ceil)
+            start_time = time.time()
+            for k in range(m):
+                # the last window is anchored at the end of the volume, so it
+                # may overlap the previous one when z is not a multiple of size_z
+                max_z = min((k + 1) * size_z, z)
+                min_z = max_z - size_z
+                image_k = image[:, :, :, :, min_z:max_z].float().to(device, non_blocking=True)
+                predict_k = model(image_k)
+                # binarize; the output is assumed to be a probability map
+                predict_k[predict_k >= 0.5] = 1
+                predict_k[predict_k < 0.5] = 0
+                predict[:, :, :, :, min_z:max_z] = predict_k.cpu().detach().numpy()
+            total_time += time.time() - start_time
+            all_dice.append(utils.evaluate(predict, label))
+
+    dice_np = np.asarray(all_dice, dtype=np.float32)
+    for i, dice_i in enumerate(all_dice):
+        logger.info("{} dice: {:.4%} ".format(i, dice_i))
+    logger.info("mean: {}".format(dice_np.mean()))
+    logger.info("std : {}".format(dice_np.std()))
+
+    # keep a single best validation checkpoint
+    if best_dice < dice_np.mean():
+        chpt_prefx = config.validing_checkpoint_prefix
+        save_dict = {'epoch': epoch + 1,
+                     'model_state_dict': model.state_dict(),
+                     'amp': amp.state_dict()}
+        save_name = "{}/best.pt".format(chpt_prefx)
+        if os.path.isfile(save_name):
+            os.remove(save_name)
+        torch.save(save_dict, save_name)
+
+    summ_writer.add_scalars('valid_avg_dice', {'valid': dice_np.mean()}, epoch + 1)
+
+    avg_time = total_time / len(all_dice)
+    logger.info("average testing time : {}".format(avg_time))
+
+    mean_dice = float(dice_np.mean())
+    writer.add_scalar('val/dice', mean_dice, epoch)
+    logger.info("Valid: [{:2d}/{}] average dice: {:.4%} ".format(epoch + 1, config.epochs, mean_dice))
+
+    return mean_dice
+
+
+def save_nd_array_as_image(data, image_name, reference_name=None):
+    """
+    Save a 3D or 2D numpy array as a medical image or RGB image.
+    inputs:
+        data: a numpy array with shape [D, H, W] or [C, H, W]
+        image_name: the output file name
+        reference_name: file name of the reference image whose meta data is used
+    outputs: None
+    """
+    data_dim = len(data.shape)
+    assert(data_dim == 2 or data_dim == 3)
+    if(image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
+       image_name.endswith(".mha")):
+        assert(data_dim == 3)
+        save_array_as_nifty_volume(data, image_name, reference_name)
+
+
+def save_array_as_nifty_volume(data, image_name, reference_name=None):
+    """
+    Save a numpy array as a nifty image.
+    inputs:
+        data: a numpy array with shape [Depth, Height, Width]
+        image_name: the output file name
+        reference_name: file name of the reference image whose affine and header are used
+    outputs: None
+    """
+    img = sitk.GetImageFromArray(data)
+    if(reference_name is not None):
+        img_ref = sitk.ReadImage(reference_name)
+        img.CopyInformation(img_ref)
+    sitk.WriteImage(img, image_name)
+
+
+if __name__ == "__main__":
+    main()
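+
+# Resuming from a snapshot saved by train() above (sketch only; the path is a
+# placeholder). Per the apex amp docs, amp.initialize() must have been called
+# before amp.load_state_dict():
+#
+#   checkpoint = torch.load("checkpoint_epoch_50.pt")
+#   model.load_state_dict(checkpoint['model_state_dict'])
+#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+#   amp.load_state_dict(checkpoint['amp'])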