Diff of /test.py [000000] .. [f804b3]

Switch to side-by-side view

--- a
+++ b/test.py
@@ -0,0 +1,220 @@
+import argparse
+import os
+
+import numpy as np
+import torch
+
+from skimage.io import imsave
+from torch.utils.data import DataLoader
+import albumentations as A
+from albumentations.pytorch import ToTensor
+from torchvision import transforms
+from tqdm import tqdm
+import cv2
+
+from common. dataset import MedicalImageDataset as Dataset
+from common.loss import bce_dice_loss,  dice_coef_metric
+from model.Att_Unet import Att_Unet
+from common.utils import log_images, gray2rgb,  outline
+
+
+def main(config):
+
+    makedirs(config)
+    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
+   # logger = Logger(config.logs)
+    loader,dataset = data_loader(config)
+
+    with torch.set_grad_enabled(False):
+        unet = Att_Unet()
+        state_dict = torch.load(config.weights, map_location=device)
+        unet.load_state_dict(state_dict)
+        unet.eval()
+        unet.to(device)
+
+
+
+        input_list = []
+        pred_list = []
+        #true_list = []
+        #loss_test = []
+        step=0
+
+        for i, data in tqdm(enumerate(loader)):
+            step += 1
+            x= data
+            x= x.to(device)
+
+            y_pred = unet(x)
+
+            #loss = bce_dice_loss(y_pred, y_true)
+            #loss_test.append(loss.item())
+
+            y_pred_np = y_pred.detach().cpu().numpy()
+            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
+
+            #y_true_np = y_true.detach().cpu().numpy()
+            #true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
+
+            x_np = x.detach().cpu().numpy()
+            input_list.extend([x_np[s] for s in range(x_np.shape[0])])
+
+        if config.mask_outline==True:
+
+            for i in range(len(input_list)):
+                image = gray2rgb(np.squeeze(input_list[i][0,:,:]))
+                image = outline(image, pred_list[i][0,:,:], color=[255, 0, 0])
+                image=image.astype("uint8")
+                filename="{}.png".format(i)
+                filepath = os.path.join(config.predictions, filename)
+                imsave(filepath, image)
+
+        else:
+            for i in range(len(input_list)):
+                 img_name=dataset.imgs[i].split("/" )[-1:][0]
+                 org_shape=cv2.imread(dataset.imgs[i]).shape
+                 mask=pred_list[i][0,:,:]
+                 mask[np.nonzero(mask<0.3)]=0.0
+                 mask[np.nonzero(mask>0.3)]=255.0
+                 mask=mask.astype("uint8")
+                 mask_copy=mask.copy()
+                 mk=cv2.resize(mask,(org_shape[0],org_shape[1]),interpolation = cv2.INTER_AREA)
+
+
+                 filename="{}".format(img_name)
+                 filepath = os.path.join(config.predictions, filename)
+                 imsave(filepath, mk)
+
+
+
+
+
+
+
+
+def makedirs(config):
+    os.makedirs(config.predictions, exist_ok=True)
+
+data_transforms = A.Compose ([
+    A.Resize(width = 256, height = 256, p=1.0),
+    A.Normalize( p=1.0),
+    ToTensor()
+])
+
+
+def data_loader(config):
+    dataset = Dataset('test', config.root,
+                    transform=data_transforms
+                    )
+
+    loader =DataLoader(
+        dataset,
+        batch_size=config.batch_size,
+        num_workers=4
+    )
+
+    return loader,dataset
+
+
+
+
+"""
+def compute_iou(model, loader, threshold=0.3):
+
+    Computes accuracy on the dataset wrapped in a loader
+    Returns: accuracy as a float value between 0 and 1
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
+    #model.eval()
+    valloss = 0
+
+    with torch.no_grad():
+
+        for i_step, (data, target) in enumerate(loader):
+
+            data = data.to(device)
+            target = target.to(device)
+
+
+            #prediction = model(x_gpu)
+
+            outputs = model(data)
+           # print("val_output:", outputs.shape)
+
+            out_cut = np.copy(outputs.data.cpu().numpy())
+            out_cut[np.nonzero(out_cut < threshold)] = 0.0
+            out_cut[np.nonzero(out_cut >= threshold)] = 1.0
+
+            picloss = dice_coef_metric(out_cut, target.data.cpu().numpy())
+            valloss += picloss
+
+        #print("Threshold:  " + str(threshold) + "  Validation DICE score:", valloss / i_step)
+
+    return valloss / i_step
+
+"""
+
+"""
+def log_loss_summary(logger, loss, step, prefix=""):
+    logger.scalar_summary(prefix + "loss", np.mean(loss), step)
+"""
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Inference for segmentation of polyps in GI"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0",
+        help="device for training (default: cuda:0)",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="input batch size for training (default: 32)",
+    )
+    parser.add_argument(
+        "--weights", type=str, default="./weights/unet.pt", required=True, help="path to weights file"
+    )
+    parser.add_argument(
+        "--root", type=str, default="./medico2020", help="root folder with images"
+    )
+
+    parser.add_argument(
+        "--mask_outline", type=bool, default=False, help="If True ,draws border line of mask else returns binary mask"
+    )
+    parser.add_argument(
+        "--image-size",
+        type=int,
+        default=256,
+        help="target input image size (default: 256)",
+    )
+    parser.add_argument(
+        "--predictions",
+        type=str,
+        default="./predictions",
+        help="folder for saving images with prediction outlines",
+    )
+    parser.add_argument(
+        "--vis-images",
+        type=int,
+        default=20,
+        help="number of visualization images to save in log file (default: 200)",
+    )
+    parser.add_argument(
+        "--vis-freq",
+        type=int,
+        default=10,
+        help="frequency of saving images to log file (default: 10)",
+    )
+
+    parser.add_argument(
+        "--logs", type=str, default="./test_logs", help="folder to save logs"
+    )
+
+    config = parser.parse_args()
+    main(config)