Diff of /main.py [000000] .. [f804b3]

--- a
+++ b/main.py
@@ -0,0 +1,304 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from tqdm import tqdm
+import albumentations as A
+from albumentations.pytorch import ToTensor
+
+
+from common.dataset import MedicalImageDataset as Dataset
+from common.logger import Logger
+from common.loss import bce_dice_loss, dice_coef_metric, _fast_hist, jaccard_index
+from model.Att_Unet import Att_Unet
+from common.utils import log_images
+
+
+def main(config):
+    makedirs(config)
+    snapshotargs(config)
+    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
+
+    loader_train, loader_valid = data_loaders(config)
+    loaders = {"train": loader_train, "valid": loader_valid}
+
+    unet = Att_Unet()
+    unet.to(device)
+
+
+    best_validation_dsc = 0.0
+
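+    # Adam with weight decay; ReduceLROnPlateau in 'max' mode halves the learning rate
+    # when the validation Dice stops improving for 15 epochs (see lr_scheduler.step(mean_dsc) below).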
+    optimizer = optim.Adam(unet.parameters(), lr=config.lr, weight_decay=1e-5)
+    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15, verbose=False)
+
+
+    logger = Logger(config.logs)
+    loss_train = []
+    loss_valid = []
+
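+    # Global counter of training iterations, used as the logging step.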
+    step = 0
+
+    for epoch in tqdm(range(config.epochs), total=config.epochs):
+        for phase in ["train", "valid"]:
+                if phase == "train":
+                    unet.train()
+                else:
+                    unet.eval()
+
+                validation_pred = []
+                validation_true = []
+                running_loss = 0.0
+
+                for i, data in enumerate(loaders[phase]):
+                    if phase == "train":
+                        step += 1
+
+                    x, y_true = data
+                    x, y_true = x.to(device), y_true.to(device)
+
+                    optimizer.zero_grad()
+
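+                    # Enable gradient tracking only during the training phase.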
+                    with torch.set_grad_enabled(phase == "train"):
+                            y_pred = unet(x)
+
+                            loss = bce_dice_loss(y_pred, y_true)
+
+                            if phase == "valid":
+                                loss_valid.append(loss.item())
+                                y_pred_np = y_pred.detach().cpu().numpy()
+                                validation_pred.extend(
+                                    [y_pred_np[s] for s in range(y_pred_np.shape[0])]
+                                )
+                                y_true_np = y_true.detach().cpu().numpy()
+                                validation_true.extend(
+                                    [y_true_np[s] for s in range(y_true_np.shape[0])]
+                                )
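+                                # Log up to config.vis_images example predictions every config.vis_freq epochs (and on the final epoch).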
+                                if (epoch % config.vis_freq == 0) or (epoch == config.epochs - 1):
+                                    if i * config.batch_size < config.vis_images:
+                                        tag = "image/{}".format(i)
+                                        num_images = config.vis_images - i * config.batch_size
+                                        logger.image_list_summary(
+                                            tag,
+                                            log_images(x, y_true, y_pred)[:num_images],
+                                            step,
+                                        )
+
+                            if phase == "train":
+                                loss_train.append(loss.item())
+                                loss.backward()
+                                optimizer.step()
+                            running_loss += loss.detach() * x.size(0)
+
+                    if i % 50 == 0:
+                        for param_group in optimizer.param_groups:
+                            print("Current learning rate is: {}".format(param_group['lr']))
+
+
+
+                    if phase == "train" and (step + 1) % 10 == 0:
+                        log_loss_summary(logger, loss_train, step)
+                        loss_train = []
+
+                print('Epoch [%d/%d], %s loss: %.4f' % (epoch + 1, config.epochs, phase, running_loss / len(loaders[phase].dataset)))
+
+                if phase == "valid":
+                    log_loss_summary(logger, loss_valid, step, prefix="val_")
+                    mean_dsc, mean_iou = compute_metric(unet, loaders[phase])
+                    logger.scalar_summary("val_dsc", mean_dsc, step)
+                    logger.scalar_summary("val_iou", mean_iou, step)
+                    lr_scheduler.step(mean_dsc)
+                    print("\nMean DICE on validation:", mean_dsc)
+                    print("Mean IOU on validation:", mean_iou)
+                    print("..........................................")
+
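+                    # Save a checkpoint only when the validation Dice improves.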
+                    if mean_dsc > best_validation_dsc:
+                        best_validation_dsc = mean_dsc
+                        torch.save(unet.state_dict(), os.path.join(config.weights, "unet.pt"))
+                    loss_valid = []
+
+
+
+    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
+
+
+def data_loaders(config):
+    dataset_train, dataset_valid = datasets(config)
+
+    loader_train = DataLoader(
+        dataset_train,
+        batch_size=config.batch_size,
+        num_workers=config.workers
+    )
+    loader_valid = DataLoader(
+        dataset_valid,
+        batch_size=config.batch_size,
+        num_workers=config.workers
+    )
+
+    return loader_train, loader_valid
+
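+# Augmentation and preprocessing pipeline; note that the same transforms are applied
+# to both the training and validation splits (see datasets()).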
+data_transforms = A.Compose([
+    A.Resize(width=256, height=256, p=1.0),
+    A.HorizontalFlip(p=0.5),
+    A.VerticalFlip(p=0.5),
+    A.Rotate((-5, 5), p=0.5),
+    A.RandomSunFlare(flare_roi=(0, 0, 1, 0.5), angle_lower=0, angle_upper=1,
+                     num_flare_circles_lower=1, num_flare_circles_upper=2,
+                     src_radius=160, src_color=(255, 255, 255), always_apply=False, p=0.2),
+    A.RGBShift(r_shift_limit=10, g_shift_limit=10,
+               b_shift_limit=10, always_apply=False, p=0.2),
+    A.ElasticTransform(alpha=2, sigma=15, alpha_affine=25, interpolation=1,
+                       border_mode=4, value=None, mask_value=None,
+                       always_apply=False, approximate=False, p=0.2),
+    A.Normalize(p=1.0),
+    ToTensor(),
+])
+
+
+def datasets(config):
+    train = Dataset('train', config.root,
+                    transform=data_transforms)
+
+
+    valid = Dataset('val', config.root,
+                    transform=data_transforms)
+
+    return train, valid
+
+
+def compute_metric(model, loader, threshold=0.3):
+    """
+    Computes accuracy on the dataset wrapped in a loader
+
+    Returns: accuracy as a float value between 0 and 1
+    """
+    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
+    #model.eval()
+    valloss_one = 0
+    valloss_two = 0
+
+    with torch.no_grad():
+
+        for i_step, (data, target) in enumerate(loader):
+
+            data = data.to(device)
+            target = target.to(device)
+
+            outputs = model(data)
+
+            # Binarize predictions at the given threshold before building the confusion histogram.
+            out_cut = np.copy(outputs.data.cpu().numpy())
+            out_cut[np.nonzero(out_cut < threshold)] = 0.0
+            out_cut[np.nonzero(out_cut >= threshold)] = 1.0
+
+            hist = _fast_hist(target.data.cpu().numpy(), out_cut, num_classes=2)
+            dice_sum += dice_coef_metric(hist)
+            iou, _ = jaccard_index(hist)
+            iou_sum += iou
+
+
+    # Average over the number of batches; enumerate starts at 0, so divide by i_step + 1.
+    return dice_sum / (i_step + 1), iou_sum / (i_step + 1)
+
+
+def log_loss_summary(logger, loss, step, prefix=""):
+    logger.scalar_summary(prefix + "loss", np.mean(loss), step)
+
+
+def makedirs(config):
+    os.makedirs(config.weights, exist_ok=True)
+    os.makedirs(config.logs, exist_ok=True)
+
+
+def snapshotargs(config):
+    config_file = os.path.join(config.logs, "config.json")
+    with open(config_file, "w") as fp:
+        json.dump(vars(config), fp)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Finetuning pretrained Unet"
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=16,
+        help="input batch size for training (default: 16)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=100,
+        help="number of epochs to train (default: 100)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=0.001,
+        help="initial learning rate (default: 0.001)",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0",
+        help="device for training (default: cuda:0)",
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=4,
+        help="number of workers for data loading (default: 4)",
+    )
+    parser.add_argument(
+        "--vis-images",
+        type=int,
+        default=100,
+        help="number of visualization images to save in log file (default: 200)",
+    )
+    parser.add_argument(
+        "--vis-freq",
+        type=int,
+        default=10,
+        help="frequency of saving images to log file (default: 10)",
+    )
+    parser.add_argument(
+        "--weights", type=str, default="./weights", help="folder to save weights"
+    )
+    parser.add_argument(
+        "--logs", type=str, default="./logs", help="folder to save logs"
+    )
+    parser.add_argument(
+        "--root", type=str, default="./medico2020", help="root folder with images"
+    )
+    parser.add_argument(
+        "--image-size",
+        type=int,
+        default=256,
+        help="target input image size (default: 256)",
+    )
+    parser.add_argument(
+        "--aug-scale",
+        type=float,
+        default=0.05,
+        help="scale factor range for augmentation (default: 0.05)",
+    )
+    parser.add_argument(
+        "--aug-angle",
+        type=int,
+        default=6,
+        help="rotation angle range in degrees for augmentation (default: 15)",
+    )
+    config = parser.parse_args()
+    main(config)