Diff of /train_infer.py [000000] .. [2162c1]


--- /dev/null
+++ b/train_infer.py
@@ -0,0 +1,252 @@
+""" Training augmented model """
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+from tensorboardX import SummaryWriter
+from ptflops import get_model_complexity_info
+import utils
+import data_generator_3D
+import time
+import SimpleITK as sitk
+import sys
+from config import TrainConfig
+from model import LCOVNet
+from apex import amp
+
+
+config = TrainConfig()
+
+device = torch.device("cuda")
+
+# tensorboard
+writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
+writer.add_text('config', config.as_markdown(), 0)
+
+logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
+config.print_params(logger.info)
+
+def main():
+    logger.info("Logger is set - training start")
+
+    # set default gpu device id
+    torch.cuda.set_device(config.gpus[0])
+
+    # set seed
+    np.random.seed(config.seed)
+    torch.manual_seed(config.seed)
+    torch.cuda.manual_seed_all(config.seed)
+
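+    # let cuDNN benchmark convolution algorithms; worthwhile here because the
+    # training patches have a fixed shape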
+    torch.backends.cudnn.benchmark = True
+
+
+    criterion = utils.log_loss().to(device)
+    model = LCOVNet(config.input_channels, config.n_classes).to(device)
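+    # ptflops runs a dummy forward pass with the given input shape
+    # (channels, H, W, D) to count MACs and parameters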
+    with torch.cuda.device(config.gpus[0]):
+        net = model
+        macs, params = get_model_complexity_info(net, (1, 240, 160, 48), as_strings=True,
+                                                 print_per_layer_stat=True, verbose=True)
+        logger.info("{:<30}  {:<8}".format('Computational complexity: ', macs))
+        logger.info("{:<30}  {:<8}".format('Number of parameters: ', params))
+
+    # model size
+    mb_params = utils.param_size(model)
+    logger.info("Model size = {:.3f} MB".format(mb_params))
+    # weights optimizer
+    optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum,
+                                weight_decay=config.weight_decay)
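+    # apex amp at opt_level "O1" patches torch ops so selected kernels run in
+    # fp16 while the model weights stay in fp32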
+    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
+
+    train_loader = data_generator_3D.Covid19TrainSet()
+    valid_loader = data_generator_3D.Covid19EvalSet()
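+    # cosine annealing decays the learning rate from config.lr toward 0 over config.epochs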
+    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
+
+    best_dice = 0.
+    # training loop
+    summ_writer = SummaryWriter(config.training_summary_dir)
+    for epoch in range(config.epochs):
+
+        # training
+        train(train_loader, model, optimizer, criterion, epoch, summ_writer)
+        lr_scheduler.step()
+        # validation
+        mean_dice = validate(valid_loader, model, criterion, epoch, summ_writer, best_dice)
+
+        # save
+        is_best = mean_dice > best_dice
+        if is_best:
+            best_dice = mean_dice
+        utils.save_checkpoint(model, config.path, is_best)
+        print("")
+
+    logger.info("Final best Dice = {:.4%}".format(best_dice))
+    utils.save_results(best_dice, config.path)
+    summ_writer.close()
+
+def train(train_loader, model, optimizer, criterion, epoch, summ_writer):
+    losses = utils.AverageMeter()
+    cur_step = epoch*len(train_loader)
+    cur_lr = optimizer.param_groups[0]['lr']
+    logger.info("Epoch {} LR {}".format(epoch, cur_lr))
+    writer.add_scalar('train/lr', cur_lr, cur_step)
+    model.train()
+    all_dice = []
+    for step, (name, X, y) in enumerate(train_loader):
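+        # the generator yields numpy batches, so convert to tensors and move to the GPU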
+        X = torch.from_numpy(X).to(device, non_blocking=True)
+        y = torch.from_numpy(y).to(device, non_blocking=True)
+        N = X.size(0)
+
+        optimizer.zero_grad()
+        logits = model(X)
+
+        loss = criterion(logits, y)
+        # amp scales the loss to avoid fp16 gradient underflow
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        # gradient clipping; with apex amp, clip the master (fp32) gradients
+        nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.grad_clip)
+        optimizer.step()
+
+        losses.update(loss.item(), N)
+
+        if step % config.print_freq == 0 or step == len(train_loader)-1:
+            logger.info(
+                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {:.3f}".format(
+                    epoch+1, config.epochs, step, len(train_loader), losses.avg))
+
+        writer.add_scalar('train/loss', loss.item(), cur_step)
+
+        # the network outputs voxel-wise probabilities; threshold at 0.5
+        predict = (logits.detach() >= 0.5).float().cpu().numpy()
+        y = y.cpu().numpy()
+        dice_i = utils.evaluate(predict, y)
+        all_dice.append(dice_i)
+        cur_step += 1
+
+    train_avg_loss = losses.avg
+    train_avg_dice = float(np.mean(all_dice))
+    loss_scalers = {'train': train_avg_loss}
+    summ_writer.add_scalars('loss', loss_scalers, epoch + 1)
+
+    dice_scalers = {'train': train_avg_dice}
+    summ_writer.add_scalars('avg_dice', dice_scalers, epoch + 1)
+
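+    # periodic snapshot of model, optimizer and amp state so training can be resumed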
+    if (epoch+1) % 50 == 0:
+        chpt_prefx = config.training_checkpoint_prefix
+        save_dict = {'epoch': epoch + 1,
+                     'model_state_dict': model.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'amp': amp.state_dict()}
+        save_name = "{0:}_{1:}.pt".format(chpt_prefx, epoch + 1)
+        torch.save(save_dict, save_name)
+    print("train_avg_loss", train_avg_loss)
+    print("train_avg_dice", train_avg_dice)
+
+def validate(valid_loader, model, criterion, epoch, summ_writer, best_dice):
+    # criterion is unused here: no loss is computed during sliding-window inference
+    model.eval()
+    all_dice = []
+    total_time = 0
+    # depth (z) of each sliding-window chunk
+    size_z = 48
+    with torch.no_grad():
+        for i, (name, image, label) in enumerate(valid_loader):
+            image = torch.from_numpy(image)
+            predict = np.zeros(shape=label.shape, dtype=label.dtype)
+            z = image.shape[4]
+            # number of depth chunks (ceiling division)
+            m = (z + size_z - 1) // size_z
+            start_time = time.time()
+            for k in range(m):
+                # the last chunk is aligned to the end of the volume, so it can
+                # overlap the previous one when z is not a multiple of size_z
+                max_z = min((k + 1) * size_z, z)
+                min_z = max(max_z - size_z, 0)
+                image_k = image[:, :, :, :, min_z:max_z].float().to(device, non_blocking=True)
+                predict_k = model(image_k)
+                # threshold voxel-wise probabilities into a binary mask
+                predict_k[predict_k >= 0.5] = 1
+                predict_k[predict_k < 0.5] = 0
+                predict[:, :, :, :, min_z:max_z] = predict_k.cpu().numpy()
+            total_time += time.time() - start_time
+            all_dice.append(utils.evaluate(predict, label))
+
+    dice_np = np.asarray(all_dice, dtype=np.float32)
+    for i, dice_i in enumerate(all_dice):
+        logger.info("{}  dice: {:.4%}".format(i, dice_i))
+    logger.info("mean: {}".format(dice_np.mean()))
+    logger.info("std : {}".format(dice_np.std()))
+
+    if best_dice < dice_np.mean():
+        # torch.save overwrites any existing file, so a single best checkpoint is kept
+        save_dict = {'epoch': epoch + 1,
+                     'model_state_dict': model.state_dict(),
+                     'amp': amp.state_dict()}
+        save_name = "{}/best.pt".format(config.validing_checkpoint_prefix)
+        torch.save(save_dict, save_name)
+
+    dice_scalers = {'valid': dice_np.mean()}
+    summ_writer.add_scalars('valid_avg_dice', dice_scalers, epoch + 1)
+
+    avg_time = total_time / len(all_dice)
+    logger.info("average testing time : {}".format(avg_time))
+
+    # no validation loss is computed in the sliding-window pass, so only dice is logged
+    mean_dice = float(dice_np.mean())
+    writer.add_scalar('val/dice', mean_dice, epoch)
+    logger.info("Valid: [{:2d}/{}] average dice: {:.4%}".format(epoch+1, config.epochs, mean_dice))
+
+    return mean_dice
+
+
+def save_nd_array_as_image(data, image_name, reference_name = None):
+    """
+    save a 3D or 2D numpy array as medical image or RGB image
+    inputs:
+        data: a numpy array with shape [D, H, W] or [C, H, W]
+        image_name: the output file name
+    outputs: None
+    """
+    data_dim = len(data.shape)
+    assert(data_dim == 2 or data_dim == 3)
+    if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
+        image_name.endswith(".mha")):
+        assert(data_dim == 3)
+        save_array_as_nifty_volume(data, image_name, reference_name)
+
+def save_array_as_nifty_volume(data, image_name, reference_name = None):
+    """
+    save a numpy array as nifty image
+    inputs:
+        data: a numpy array with shape [Depth, Height, Width]
+        image_name: the output file name
+        reference_name: file name of the reference image of which affine and header are used
+    outputs: None
+    """
+    img = sitk.GetImageFromArray(data)
+    if(reference_name is not None):
+        img_ref = sitk.ReadImage(reference_name)
+        img.CopyInformation(img_ref)
+    sitk.WriteImage(img, image_name)
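+# A minimal usage sketch (hypothetical file names): write a predicted mask with
+# the geometry of the scan it was derived from:
+#     save_nd_array_as_image(mask_array, "case01_seg.nii.gz",
+#                            reference_name="case01_ct.nii.gz")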
+
+
+if __name__ == "__main__":
+    main()