|
a |
|
b/model_utils.py |
|
|
1 |
import os |
|
|
2 |
import pandas as pd |
|
|
3 |
import numpy as np |
|
|
4 |
import matplotlib.pyplot as plt |
|
|
5 |
from sklearn.metrics import log_loss, confusion_matrix |
|
|
6 |
import tensorflow as tf |
|
|
7 |
|
|
|
8 |
import config |
|
|
9 |
from utils import store_to_csv, read_csv |
|
|
10 |
|
|
|
11 |
# Network Input Parameters |
|
|
12 |
n_x = config.IMAGE_PXL_SIZE_X |
|
|
13 |
n_y = config.IMAGE_PXL_SIZE_Y |
|
|
14 |
n_z = config.SLICES |
|
|
15 |
num_channels = config.NUM_CHANNELS |
|
|
16 |
|
|
|
17 |
# tf Graph input |
|
|
18 |
x = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, n_z, n_x, n_y, num_channels), |
|
|
19 |
name='input') |
|
|
20 |
y = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE,), name='label') |
|
|
21 |
keep_prob = tf.placeholder(tf.float32, name='dropout') #dropout (keep probability) |
|
|
22 |
|
|
|
23 |
input_img = tf.placeholder(tf.float32, |
|
|
24 |
shape=(1, config.SLICES, config.IMAGE_PXL_SIZE_X, config.IMAGE_PXL_SIZE_Y)) |
|
|
25 |
# Reshape input picture, first dimension is kept to be able to support batches |
|
|
26 |
reshape_op = tf.reshape(input_img, |
|
|
27 |
shape=(-1, config.SLICES, config.IMAGE_PXL_SIZE_X, config.IMAGE_PXL_SIZE_Y, 1)) |
|
|
28 |
|
|
|
29 |
input_test_img = tf.placeholder(tf.float32, |
|
|
30 |
shape=(config.SLICES, config.IMAGE_PXL_SIZE_X, config.IMAGE_PXL_SIZE_Y)) |
|
|
31 |
# Reshape test input picture |
|
|
32 |
reshape_test_op = tf.reshape(input_test_img, |
|
|
33 |
shape=(-1, config.SLICES, config.IMAGE_PXL_SIZE_X, config.IMAGE_PXL_SIZE_Y, 1)) |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def store_error_plots(validation_err, train_err): |
|
|
37 |
try: |
|
|
38 |
plt.plot(validation_err) |
|
|
39 |
plt.savefig("validation_errors.png") |
|
|
40 |
|
|
|
41 |
plt.plot(train_err) |
|
|
42 |
plt.savefig("train_errors.png") |
|
|
43 |
except Exception as e: |
|
|
44 |
print("Drawing errors failed with: {}".format(e)) |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
def high_error_increase(errors, |
|
|
48 |
current, |
|
|
49 |
least_count=3, |
|
|
50 |
incr_threshold=0.1): |
|
|
51 |
if len(errors) < least_count: |
|
|
52 |
return False |
|
|
53 |
|
|
|
54 |
return any(current - x >= incr_threshold |
|
|
55 |
for x in errors) |
|
|
56 |
|
|
|
57 |
|
|
|
58 |
def get_max_prob(output, ind_value): |
|
|
59 |
max_prob = output[ind_value] |
|
|
60 |
if ind_value == config.NO_CANCER_CLS: |
|
|
61 |
max_prob = 1.0 - max_prob |
|
|
62 |
|
|
|
63 |
return max_prob |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
def accuracy(predictions, labels): |
|
|
67 |
return (100 * np.sum(np.argmax(predictions, 1) == labels) |
|
|
68 |
/ predictions.shape[0]) |
|
|
69 |
|
|
|
70 |
|
|
|
71 |
def evaluate_log_loss(predictions, target_labels): |
|
|
72 |
return log_loss(target_labels, predictions, labels=[0, 1]) |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
def get_confusion_matrix(target_labels, predictions, labels=[0, 1]): |
|
|
76 |
predicted_labels = np.argmax(predictions, 1) |
|
|
77 |
return confusion_matrix(target_labels, predicted_labels, labels) |
|
|
78 |
|
|
|
79 |
|
|
|
80 |
def display_confusion_matrix_info(target_labels, predictions, labels=[0, 1]): |
|
|
81 |
matrix = get_confusion_matrix(target_labels, predictions, labels) |
|
|
82 |
print("True negatives count: ", matrix[0][0]) |
|
|
83 |
print("False negatives count: ", matrix[1][0]) |
|
|
84 |
print("True positives count: ", matrix[1][1]) |
|
|
85 |
print("False positives count: ", matrix[0][1]) |
|
|
86 |
|
|
|
87 |
return matrix |
|
|
88 |
|
|
|
89 |
def get_sensitivity(confusion_matrix): |
|
|
90 |
true_positives = confusion_matrix[1][1] |
|
|
91 |
false_negatives = confusion_matrix[1][0] |
|
|
92 |
|
|
|
93 |
return true_positives / float(true_positives + false_negatives) |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
def get_specificity(confusion_matrix): |
|
|
97 |
true_negatives = confusion_matrix[0][0] |
|
|
98 |
false_positives = confusion_matrix[0][1] |
|
|
99 |
|
|
|
100 |
return true_negatives / float(true_negatives + false_positives) |
|
|
101 |
|
|
|
102 |
|
|
|
103 |
def calculate_conv_output_size(x, y, z, strides, filters, paddings, last_depth): |
|
|
104 |
# Currently axes are transposed [z, x, y] |
|
|
105 |
for i, stride in enumerate(strides): |
|
|
106 |
if paddings[i] == 'VALID': |
|
|
107 |
f = filters[i] |
|
|
108 |
x = np.ceil(np.float((x - f[1] + 1) / float(stride[1]))) |
|
|
109 |
y = np.ceil(np.float((y - f[2] + 1) / float(stride[2]))) |
|
|
110 |
z = np.ceil(np.float((z - f[0] + 1) / float(stride[0]))) |
|
|
111 |
else: |
|
|
112 |
x = np.ceil(float(x) / float(stride[1])) |
|
|
113 |
y = np.ceil(float(y) / float(stride[2])) |
|
|
114 |
z = np.ceil(float(z) / float(stride[0])) |
|
|
115 |
|
|
|
116 |
return int(x * y * z * last_depth) |
|
|
117 |
|
|
|
118 |
|
|
|
119 |
def model_store_path(store_dir, step): |
|
|
120 |
return os.path.join(store_dir, |
|
|
121 |
'model_{}.ckpt'.format(step)) |
|
|
122 |
|
|
|
123 |
|
|
|
124 |
def validate_data_loaded(images_batch, images_labels): |
|
|
125 |
if not (len(images_labels) and len(images_labels)): |
|
|
126 |
print("Please check you configurations, unable to laod the images...") |
|
|
127 |
return False |
|
|
128 |
return True |
|
|
129 |
|
|
|
130 |
|
|
|
131 |
def evaluate_validation_set(sess, |
|
|
132 |
validation_set, |
|
|
133 |
valid_prediction, |
|
|
134 |
feed_data_key, |
|
|
135 |
batch_size): |
|
|
136 |
validation_pred = [] |
|
|
137 |
validation_labels = [] |
|
|
138 |
|
|
|
139 |
index = 0 |
|
|
140 |
while index < validation_set.num_samples: |
|
|
141 |
validation_batch, validation_label = validation_set.next_batch(batch_size) |
|
|
142 |
if not validate_data_loaded(validation_batch, validation_label): |
|
|
143 |
return (0, 0, 0, 0) |
|
|
144 |
reshaped = sess.run(reshape_op, feed_dict={input_img: np.stack(validation_batch)}) |
|
|
145 |
batch_pred = sess.run(valid_prediction, |
|
|
146 |
feed_dict={feed_data_key: reshaped, keep_prob: 1.}) |
|
|
147 |
|
|
|
148 |
validation_pred.extend(batch_pred) |
|
|
149 |
validation_labels.extend(validation_label) |
|
|
150 |
index += batch_size |
|
|
151 |
|
|
|
152 |
validation_acc = accuracy(np.stack(validation_pred), |
|
|
153 |
np.stack(validation_labels)) |
|
|
154 |
validation_log_loss = evaluate_log_loss(validation_pred, |
|
|
155 |
validation_labels) |
|
|
156 |
|
|
|
157 |
confusion_matrix = display_confusion_matrix_info(validation_labels, validation_pred) |
|
|
158 |
sensitivity = get_sensitivity(confusion_matrix) |
|
|
159 |
specificity = get_specificity(confusion_matrix) |
|
|
160 |
|
|
|
161 |
return (validation_acc, validation_log_loss, sensitivity, specificity) |
|
|
162 |
|
|
|
163 |
|
|
|
164 |
def evaluate_test_set(sess, |
|
|
165 |
test_set, |
|
|
166 |
test_prediction, |
|
|
167 |
feed_data_key, |
|
|
168 |
export_csv=True): |
|
|
169 |
i = 0 |
|
|
170 |
patients, probs = [], [] |
|
|
171 |
|
|
|
172 |
try: |
|
|
173 |
while i < test_set.num_samples: |
|
|
174 |
patient, test_img = test_set.next_patient() |
|
|
175 |
# returns index of column with highest probability |
|
|
176 |
# [first class=no cancer=0, second class=cancer=1] |
|
|
177 |
if len(test_img): |
|
|
178 |
test_img = sess.run(reshape_test_op, feed_dict={input_test_img: test_img}) |
|
|
179 |
i += 1 |
|
|
180 |
patients.append(patient) |
|
|
181 |
output = sess.run(test_prediction, |
|
|
182 |
feed_dict={feed_data_key: test_img, keep_prob: 1.}) |
|
|
183 |
max_ind_f = tf.argmax(output, 1) |
|
|
184 |
ind_value = sess.run(max_ind_f) |
|
|
185 |
max_prob = get_max_prob(output[0], ind_value[0]) |
|
|
186 |
probs.append(max_prob) |
|
|
187 |
|
|
|
188 |
print("Output {} for patient with id {}, predicted output {}.".format( |
|
|
189 |
max_prob, patient, output[0])) |
|
|
190 |
|
|
|
191 |
else: |
|
|
192 |
print("Corrupted test image, incorrect shape for patient {}".format( |
|
|
193 |
patient)) |
|
|
194 |
|
|
|
195 |
if export_csv: |
|
|
196 |
store_to_csv(patients, probs, config.SOLUTION_FILE_PATH) |
|
|
197 |
except Exception as e: |
|
|
198 |
print("Storing results failed with: {} Probably solution file is incomplete.".format(e)) |
|
|
199 |
|
|
|
200 |
|
|
|
201 |
def evaluate_solution(sample_solution, with_merged_report=True): |
|
|
202 |
true_labels = read_csv(config.REAL_SOLUTION_CSV) |
|
|
203 |
predictions = read_csv(sample_solution) |
|
|
204 |
patients = true_labels.index.values |
|
|
205 |
|
|
|
206 |
probs, labels, probs_cls = [], [], [] |
|
|
207 |
for patient in patients: |
|
|
208 |
prob = predictions.get_value(patient, config.COLUMN_NAME) |
|
|
209 |
probs.append(prob) |
|
|
210 |
probs_cls.append([1.0 - prob, prob]) |
|
|
211 |
labels.append(true_labels.get_value(patient, config.COLUMN_NAME)) |
|
|
212 |
|
|
|
213 |
probs_cls = np.array(probs_cls) |
|
|
214 |
log_loss_err = evaluate_log_loss(probs_cls, labels) |
|
|
215 |
acc = accuracy(probs_cls, np.array(labels)) |
|
|
216 |
|
|
|
217 |
confusion_matrix = display_confusion_matrix_info(labels, probs_cls) |
|
|
218 |
sensitivity = get_sensitivity(confusion_matrix) |
|
|
219 |
specificity = get_specificity(confusion_matrix) |
|
|
220 |
|
|
|
221 |
print("Log loss: ", round(log_loss_err, 5)) |
|
|
222 |
print("Accuracy: %.1f%%" % acc) |
|
|
223 |
print("Sensitivity: ", round(sensitivity, 5)) |
|
|
224 |
print("Specificity: ", round(specificity, 5)) |
|
|
225 |
|
|
|
226 |
if with_merged_report: |
|
|
227 |
df = pd.DataFrame(data={'prediction': probs, 'label': labels}, |
|
|
228 |
columns=['prediction', 'label'], |
|
|
229 |
index=true_labels.index) |
|
|
230 |
df.to_csv('report_{}'.format(os.path.basename(sample_solution))) |
|
|
231 |
|
|
|
232 |
return (log_loss_err, acc, sensitivity, specificity) |