# flair-segmentation/test.py
1
from __future__ import print_function
2
3
import matplotlib
4
5
matplotlib.use("Agg")
6
import cv2
7
import matplotlib.pyplot as plt
8
import numpy as np
9
import os
10
import sys
11
import tensorflow as tf
12
import warnings
13
14
warnings.filterwarnings("ignore")
15
16
from keras import backend as K
17
from scipy.io import savemat
18
from skimage.io import imsave
19
20
from data import load_data
21
from net import unet
22
23
# Trained network weights loaded by predict().
weights_path = "./weights_128.h5"
# Training images; read only when predict() is asked to recompute mean/std.
train_images_path = "./data/train/"
# Images that predict() segments and evaluates.
test_images_path = "./data/valid/"
# Output directory for predictions.mat and the per-slice overlay PNGs.
predictions_path = "./predictions/"

# Default GPU index; the first command-line argument overrides it.
gpu = "0"
29
30
31
def _contour_image(channel):
    """Return a float image with the outer contours of a binary map drawn at 255.

    `channel` is a 2-D map; it is rounded to {0, 1} and scaled to uint8 first.
    """
    binary = (np.round(channel) * 255.0).astype(np.uint8)
    # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3 but
    # (contours, hierarchy) in OpenCV 2 and 4; index -2 is the contour list
    # in every version.
    contours = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )[-2]
    canvas = np.zeros(binary.shape)
    cv2.drawContours(canvas, contours, -1, (255, 0, 0), 1)
    return canvas


def _overlay_contours(gray, pred_channel, mask_channel):
    """Overlay prediction (channel 0) and ground-truth (channel 2) contours.

    Pixels on either contour are first blanked in all channels; then the
    ground-truth contour is written to channel 2 and the prediction contour
    to channel 0 of an RGB copy of `gray`.
    """
    pred_contour = _contour_image(pred_channel)
    mask_contour = _contour_image(mask_channel)

    overlay = np.array(gray2rgb(gray))
    annotation = overlay[:, :, 1]
    annotation[np.maximum(pred_contour, mask_contour) == 255] = 0
    overlay[:, :, 0] = overlay[:, :, 1] = overlay[:, :, 2] = annotation
    overlay[:, :, 2] = np.maximum(overlay[:, :, 2], mask_contour)
    overlay[:, :, 0] = np.maximum(overlay[:, :, 0], pred_contour)
    return overlay


def predict(mean=20.0, std=43.0):
    """Segment the test images, save the raw results and contour overlays.

    Args:
        mean: value subtracted from the test images before prediction.
        std: value the test images are divided by before prediction.
            Passing the pair (0.0, 1.0) is a sentinel: mean and std are then
            recomputed from the images in `train_images_path`.

    Writes `predictions.mat` (keys "pred", "image", "mask", "name") and one
    `<name>.png` overlay per test sample into `predictions_path`.

    Returns:
        (imgs_mask_test, imgs_mask_pred, names_test): ground-truth masks,
        network outputs, and sample names, in the order load_data returned.
    """
    # load and normalize data
    if mean == 0.0 and std == 1.0:
        imgs_train, _, _ = load_data(train_images_path)
        mean = np.mean(imgs_train)
        std = np.std(imgs_train)

    imgs_test, imgs_mask_test, names_test = load_data(test_images_path)
    # astype() makes a copy, so the in-place normalisation below does not
    # touch the images used for the overlays.
    original_imgs_test = imgs_test.astype(np.uint8)

    imgs_test -= mean
    imgs_test /= std

    # load model with weights
    model = unet()
    model.load_weights(weights_path)

    # make predictions
    imgs_mask_pred = model.predict(imgs_test, verbose=1)

    # save to mat file for further processing
    if not os.path.exists(predictions_path):
        os.makedirs(predictions_path)

    matdict = {
        "pred": imgs_mask_pred,
        "image": original_imgs_test,
        "mask": imgs_mask_test,
        "name": names_test,
    }
    savemat(os.path.join(predictions_path, "predictions.mat"), matdict)

    # save images with segmentation and ground truth mask overlay
    for image, mask, pred, name in zip(
        original_imgs_test, imgs_mask_test, imgs_mask_pred, names_test
    ):
        # per-sample arrays are indexed (H, W, channel); channel 1 of the
        # image is the middle slice the segmentation mask refers to
        overlay = _overlay_contours(image[:, :, 1], pred[:, :, 0], mask[:, :, 0])
        imsave(os.path.join(predictions_path, name + ".png"), overlay)

    return imgs_mask_test, imgs_mask_pred, names_test
99
100
101
def evaluate(imgs_mask_test, imgs_mask_pred, names_test, metric=None):
    """Score each case (patient) by stacking all of its slices.

    Slice names are expected in the format <case_id>_<slice_number>; the case
    id is everything before the last underscore. Slices are grouped by exact
    case id, so ids that are substrings of one another, or that contain
    underscores themselves, are kept apart.

    Args:
        imgs_mask_test: ground-truth masks, one per slice.
        imgs_mask_pred: predicted masks, aligned with `imgs_mask_test`.
        names_test: slice names, aligned with the masks.
        metric: callable ``metric(pred_stack, mask_stack)`` applied per case;
            defaults to `dice_coefficient`.

    Returns:
        (values, patient_ids): per-case scores and the matching case ids,
        ordered by each case's first slice name in sorted order. Each score is
        also printed as ``<case_id>:\\t<score>``.
    """
    import collections

    if metric is None:
        metric = dice_coefficient

    # sorted() works on both Python 2 lists and Python 3 zip iterators.
    records = sorted(
        zip(names_test, imgs_mask_test, imgs_mask_pred), key=lambda r: r[0]
    )

    # case id -> (ground-truth slices, predicted slices), kept in
    # first-appearance order so output follows the sorted slice names
    cases = collections.OrderedDict()
    for name, mask, pred in records:
        case_id = "_".join(name.split("_")[:-1])
        masks, preds = cases.setdefault(case_id, ([], []))
        masks.append(mask)
        preds.append(pred)

    patient_ids = []
    dc_values = []
    for case_id, (masks, preds) in cases.items():
        patient_ids.append(case_id)
        dc_values.append(metric(np.array(preds), np.array(masks)))
        print(case_id + ":\t" + str(dc_values[-1]))

    return dc_values, patient_ids
132
133
134
def dice_coefficient(prediction, ground_truth):
    """Return the Dice similarity coefficient 2|P∩G| / (|P| + |G|).

    Both inputs are rounded to integers first; the overlap counts predicted
    values at positions where the ground truth equals 1. Inputs are presumably
    probability/binary maps in [0, 1] of the same shape — confirm for callers.

    When neither input contains any foreground the two segmentations agree
    perfectly, so 1.0 is returned instead of the NaN a 0/0 division produces
    (a NaN would also poison any average taken over cases).
    """
    prediction = np.round(prediction).astype(int)
    ground_truth = np.round(ground_truth).astype(int)
    denominator = np.sum(prediction) + np.sum(ground_truth)
    if denominator == 0:
        return 1.0
    return np.sum(prediction[ground_truth == 1]) * 2.0 / denominator
142
143
144
def gray2rgb(im):
    """Replicate a 2-D grayscale image into a (rows, cols, 3) uint8 image.

    Values are cast to uint8 on assignment; no rescaling is performed.
    A non-2-D input raises ValueError when its shape is unpacked.
    """
    rows, cols = im.shape
    rgb = np.empty((rows, cols, 3), dtype=np.uint8)
    # Broadcast the single plane across all three channels at once.
    rgb[...] = im[:, :, np.newaxis]
    return rgb
149
150
151
def plot_dc(labels, values, output_path="DSC.png"):
    """Save a horizontal bar chart of per-case Dice coefficients.

    Args:
        labels: one y-axis label per bar (e.g. case ids).
        values: Dice coefficient per label, same length as `labels`.
        output_path: image file to write; defaults to "DSC.png".

    The x axis is limited to [0.5, 1.0]; a green vertical line marks the mean
    of `values`.
    """
    y_pos = np.arange(len(labels))

    # Work on one explicit Axes. A bare plt.axes() call adds a *new* Axes in
    # current matplotlib, covering the bars with an empty plot.
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(y_pos, values, align="center", alpha=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.set_xticks(np.arange(0.5, 1.0, 0.05))
    ax.set_xlabel("Dice coefficient", fontsize="x-large")
    ax.xaxis.grid(color="black", linestyle="-", linewidth=0.5)
    ax.set_xlim([0.5, 1.0])
    fig.tight_layout()
    ax.axvline(np.mean(values), color="green", linewidth=2)

    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
167
168
169
if __name__ == "__main__":

    # Let TensorFlow allocate GPU memory on demand instead of all up front.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=session_config))

    # An optional first command-line argument overrides the default GPU index.
    if len(sys.argv) > 1:
        gpu = sys.argv[1]

    with tf.device("/gpu:" + gpu):
        imgs_mask_test, imgs_mask_pred, names_test = predict()
        values, labels = evaluate(imgs_mask_test, imgs_mask_pred, names_test)

    print("\nAverage DSC: " + str(np.mean(values)))

    # Bar chart of per-case scores.
    plot_dc(labels, values)