a b/inference.py
1
import argparse
2
import os
3
4
import numpy as np
5
import torch
6
from matplotlib import pyplot as plt
7
from matplotlib.backends.backend_agg import FigureCanvasAgg
8
from medpy.filter.binary import largest_connected_component
9
from skimage.io import imsave
10
from torch.utils.data import DataLoader
11
from tqdm import tqdm
12
13
from dataset import BrainSegmentationDataset as Dataset
14
from unet import UNet
15
from utils import dsc, gray2rgb, outline
16
17
18
def main(args):
    """Run the trained U-Net over the validation subset and save its outputs.

    Writes a per-patient DSC bar chart to ``args.figure`` and, for every slice
    of every patient volume, a PNG into ``args.predictions`` with the predicted
    mask outlined in red and the ground-truth mask outlined in green.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).
    """
    makedirs(args)
    # Fall back to the CPU when CUDA is not available, whatever --device says.
    if torch.cuda.is_available():
        device = torch.device(args.device)
    else:
        device = torch.device("cpu")

    loader = data_loader(args)

    inputs, predictions, targets = [], [], []
    with torch.no_grad():
        unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
        # NOTE(review): torch.load unpickles the file, so only load trusted weights.
        unet.load_state_dict(torch.load(args.weights, map_location=device))
        unet.eval()
        unet.to(device)

        for _, (x, y_true) in tqdm(enumerate(loader)):
            x = x.to(device)
            y_true = y_true.to(device)
            y_pred = unet(x)
            # Iterating a NumPy array walks its first (batch) axis, so each
            # extend() appends one array per sample of the batch.
            predictions.extend(y_pred.detach().cpu().numpy())
            targets.extend(y_true.detach().cpu().numpy())
            inputs.extend(x.detach().cpu().numpy())

    dataset = loader.dataset
    volumes = postprocess_per_volume(
        inputs,
        predictions,
        targets,
        dataset.patient_slice_index,
        dataset.patients,
    )

    imsave(args.figure, plot_dsc(dsc_distribution(volumes)))

    for patient, (vol_in, vol_pred, vol_true) in volumes.items():
        for s in range(vol_in.shape[0]):
            # Channel 1 of the input is the FLAIR sequence.
            image = gray2rgb(vol_in[s, 1])
            image = outline(image, vol_pred[s, 0], color=[255, 0, 0])
            image = outline(image, vol_true[s, 0], color=[0, 255, 0])
            filename = "{}-{:02d}.png".format(patient, s)
            imsave(os.path.join(args.predictions, filename), image)
73
74
75
def data_loader(args):
    """Create the DataLoader used for inference on the validation subset.

    Args:
        args: parsed CLI namespace; reads ``images``, ``image_size`` and
            ``batch_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the validation ``Dataset``,
        built with ``random_sampling=False``.
    """
    validation_set = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        random_sampling=False,
    )
    # Keep sequential order (shuffle=False is DataLoader's default) and every
    # sample (drop_last=False): predictions are regrouped into per-patient
    # volumes purely by position afterwards.
    return DataLoader(
        validation_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=1,
    )
86
87
88
def postprocess_per_volume(
    input_list, pred_list, true_list, patient_slice_index, patients
):
    """Regroup per-slice arrays into per-patient volumes and clean predictions.

    Slices are assigned to patients purely by position: the number of slices
    for patient ``p`` is the count of entries in ``patient_slice_index`` whose
    first element is ``p``. This assumes the three lists are ordered patient
    by patient, in increasing patient index (the loader must not shuffle).

    Args:
        input_list: per-slice input arrays.
        pred_list: per-slice predicted probability maps; rounded to {0, 1}.
        true_list: per-slice ground-truth masks.
        patient_slice_index: sequence of ``(patient_index, slice_index)`` pairs.
        patients: patient names, indexed by patient index.

    Returns:
        dict mapping patient name -> ``(volume_in, volume_pred, volume_true)``.
        ``volume_pred`` is a boolean array that keeps only the largest
        connected component of the thresholded prediction. It is all False
        when nothing was predicted for that patient.
    """
    volumes = {}
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p, count in enumerate(num_slices):
        stop = index + count
        volume_in = np.array(input_list[index:stop])
        volume_pred = np.round(np.array(pred_list[index:stop])).astype(int)
        if volume_pred.any():
            volume_pred = largest_connected_component(volume_pred)
        else:
            # largest_connected_component takes argmax over the component
            # sizes and raises on an empty prediction, so emit an empty mask
            # of the same dtype directly.
            volume_pred = np.zeros(volume_pred.shape, dtype=bool)
        volume_true = np.array(true_list[index:stop])
        volumes[patients[p]] = (volume_in, volume_pred, volume_true)
        index = stop
    return volumes
104
105
106
def dsc_distribution(volumes):
    """Compute the Dice coefficient of every patient volume.

    Args:
        volumes: mapping of patient name -> sequence whose items 1 and 2 are
            the predicted and the ground-truth volume.

    Returns:
        dict mapping patient name -> ``dsc(prediction, truth, lcc=False)``,
        in the iteration order of ``volumes``.
    """
    return {
        patient: dsc(volume[1], volume[2], lcc=False)
        for patient, volume in volumes.items()
    }
113
114
115
def plot_dsc(dsc_dist):
    """Render a horizontal bar chart of per-patient Dice scores as an image.

    Bars are sorted by ascending score. Vertical lines mark the mean
    ("tomato") and the median ("forestgreen") score.

    Args:
        dsc_dist: mapping of patient name -> Dice coefficient. Labels are
            shortened to the underscore-separated fields between the first
            and the last one (e.g. "A_B_C_D" -> "B_C").

    Returns:
        ``np.ndarray`` of dtype uint8 and shape (height, width, 4) holding
        the RGBA pixels of the rendered figure.
    """
    # The OO API avoids pyplot's global figure registry, so nothing needs
    # closing and nothing leaks if rendering raises.
    from matplotlib.figure import Figure

    items = sorted(dsc_dist.items(), key=lambda item: item[1])
    values = [value for _, value in items]
    labels = ["_".join(name.split("_")[1:-1]) for name, _ in items]
    y_positions = np.arange(len(items))

    fig = Figure(figsize=(12, 8))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(1, 1, 1)
    ax.barh(y_positions, values, align="center", color="skyblue")
    ax.set_yticks(y_positions)
    ax.set_yticklabels(labels)
    # linspace includes the right edge, so 1.0 gets a tick as well.
    ax.set_xticks(np.linspace(0.0, 1.0, 11))
    ax.set_xlim([0.0, 1.0])
    if values:
        # np.mean/np.median of an empty list give NaN plus a RuntimeWarning.
        ax.axvline(np.mean(values), color="tomato", linewidth=2)
        ax.axvline(np.median(values), color="forestgreen", linewidth=2)
    ax.set_xlabel("Dice coefficient", fontsize="x-large")
    ax.xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
    fig.tight_layout()
    canvas.draw()

    buffer, (width, height) = canvas.print_to_buffer()
    # frombuffer replaces the deprecated binary mode of np.fromstring. The
    # copy keeps the result writable, as fromstring's was.
    pixels = np.frombuffer(buffer, dtype=np.uint8)
    return pixels.reshape((height, width, 4)).copy()
136
137
138
def makedirs(args):
    """Create ``args.predictions`` (and any parents) unless it already exists."""
    target_dir = args.predictions
    os.makedirs(target_dir, exist_ok=True)
140
141
142
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``device``, ``batch_size``, ``weights``,
        ``images``, ``image_size``, ``predictions`` and ``figure``.
    """
    parser = argparse.ArgumentParser(
        description="Inference for segmentation of brain MRI"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="device for inference (default: cuda:0)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="input batch size for inference (default: 32)",
    )
    parser.add_argument(
        "--weights", type=str, required=True, help="path to weights file"
    )
    parser.add_argument(
        "--images", type=str, default="./kaggle_3m", help="root folder with images"
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=256,
        help="target input image size (default: 256)",
    )
    parser.add_argument(
        "--predictions",
        type=str,
        default="./predictions",
        help="folder for saving images with prediction outlines",
    )
    parser.add_argument(
        "--figure",
        type=str,
        default="./dsc.png",
        help="filename for DSC distribution figure",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())