# test.py
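"""Inference script for polyp segmentation.

Loads Att_Unet weights, runs the test split of MedicalImageDataset through the
network, and writes either mask outlines drawn on the input images or resized
binary masks to the folder given by --predictions.
"""
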
import argparse
import os

import numpy as np
import torch

from skimage.io import imsave
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensor
from torchvision import transforms
from tqdm import tqdm
import cv2

from common.dataset import MedicalImageDataset as Dataset
from common.loss import bce_dice_loss, dice_coef_metric
from model.Att_Unet import Att_Unet
from common.utils import log_images, gray2rgb, outline


def main(config):
    makedirs(config)
    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
    # logger = Logger(config.logs)
    loader, dataset = data_loader(config)

    with torch.set_grad_enabled(False):
        unet = Att_Unet()
        state_dict = torch.load(config.weights, map_location=device)
        unet.load_state_dict(state_dict)
        unet.eval()
        unet.to(device)

        input_list = []
        pred_list = []
        # true_list = []
        # loss_test = []
        step = 0

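        # Run the test set through the network and collect inputs and predicted
        # masks as NumPy arrays for the post-processing below.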
        for i, data in tqdm(enumerate(loader)):
            step += 1
            x = data
            x = x.to(device)

            y_pred = unet(x)

            # loss = bce_dice_loss(y_pred, y_true)
            # loss_test.append(loss.item())

            y_pred_np = y_pred.detach().cpu().numpy()
            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])

            # y_true_np = y_true.detach().cpu().numpy()
            # true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])

            x_np = x.detach().cpu().numpy()
            input_list.extend([x_np[s] for s in range(x_np.shape[0])])

        if config.mask_outline:
            # Draw the predicted mask outline in red on top of the input image.
            for i in range(len(input_list)):
                image = gray2rgb(np.squeeze(input_list[i][0, :, :]))
                image = outline(image, pred_list[i][0, :, :], color=[255, 0, 0])
                image = image.astype("uint8")
                filename = "{}.png".format(i)
                filepath = os.path.join(config.predictions, filename)
                imsave(filepath, image)

        else:
            # Save the binary mask, resized back to the original image size.
            for i in range(len(input_list)):
                img_name = os.path.basename(dataset.imgs[i])
                org_shape = cv2.imread(dataset.imgs[i]).shape
                mask = pred_list[i][0, :, :]
                # Binarize at a threshold of 0.3.
                mask[mask < 0.3] = 0.0
                mask[mask >= 0.3] = 255.0
                mask = mask.astype("uint8")
                # cv2.resize expects (width, height), i.e. (cols, rows).
                mk = cv2.resize(mask, (org_shape[1], org_shape[0]), interpolation=cv2.INTER_AREA)

                filepath = os.path.join(config.predictions, img_name)
                imsave(filepath, mk)


def makedirs(config):
    os.makedirs(config.predictions, exist_ok=True)


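# Test-time preprocessing; this should mirror the transforms used during training.
# Note: the resize is hard-coded to 256x256, so the --image-size flag is not applied here.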
data_transforms = A.Compose([
    A.Resize(width=256, height=256, p=1.0),
    A.Normalize(p=1.0),
    ToTensor()
])


def data_loader(config):
    dataset = Dataset('test', config.root, transform=data_transforms)

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,  # keep order aligned with dataset.imgs for filename lookup
        num_workers=4,
    )

    return loader, dataset


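# The helpers below are kept for reference but disabled: compute_iou needs
# ground-truth masks, which the test loader above does not provide.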
"""
122
def compute_iou(model, loader, threshold=0.3):
123
124
    Computes accuracy on the dataset wrapped in a loader
125
    Returns: accuracy as a float value between 0 and 1
126
127
    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
128
    #model.eval()
129
    valloss = 0
130
131
    with torch.no_grad():
132
133
        for i_step, (data, target) in enumerate(loader):
134
135
            data = data.to(device)
136
            target = target.to(device)
137
138
139
            #prediction = model(x_gpu)
140
141
            outputs = model(data)
142
           # print("val_output:", outputs.shape)
143
144
            out_cut = np.copy(outputs.data.cpu().numpy())
145
            out_cut[np.nonzero(out_cut < threshold)] = 0.0
146
            out_cut[np.nonzero(out_cut >= threshold)] = 1.0
147
148
            picloss = dice_coef_metric(out_cut, target.data.cpu().numpy())
149
            valloss += picloss
150
151
        #print("Threshold:  " + str(threshold) + "  Validation DICE score:", valloss / i_step)
152
153
    return valloss / i_step
154
155
"""
156
157
"""
158
def log_loss_summary(logger, loss, step, prefix=""):
159
    logger.scalar_summary(prefix + "loss", np.mean(loss), step)
160
"""
161
162
163
164
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inference for segmentation of polyps in the GI tract"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="device for inference (default: cuda:0)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="input batch size for inference (default: 32)",
    )
    parser.add_argument(
        "--weights", type=str, required=True, help="path to the weights file (e.g. ./weights/unet.pt)"
    )
    parser.add_argument(
        "--root", type=str, default="./medico2020", help="root folder with images"
    )
    parser.add_argument(
        "--mask-outline",
        action="store_true",
        help="if set, draw the mask outline on the input image instead of saving the binary mask",
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=256,
        help="target input image size (default: 256)",
    )
    parser.add_argument(
        "--predictions",
        type=str,
        default="./predictions",
        help="folder for saving predicted masks or outline images",
    )
    parser.add_argument(
        "--vis-images",
        type=int,
        default=20,
        help="number of visualization images to save in the log file (default: 20)",
    )
    parser.add_argument(
        "--vis-freq",
        type=int,
        default=10,
        help="frequency of saving images to the log file (default: 10)",
    )
    parser.add_argument(
        "--logs", type=str, default="./test_logs", help="folder to save logs"
    )

    config = parser.parse_args()
    main(config)
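
# Example invocation (paths are illustrative and depend on your local setup):
#   python test.py --weights ./weights/unet.pt --root ./medico2020 --predictions ./predictions
#   python test.py --weights ./weights/unet.pt --root ./medico2020 --mask-outline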