Diff of /model_utils.py [000000] .. [4f54f1]

Switch to unified view

a b/model_utils.py
1
import os
2
import pandas as pd
3
import numpy as np
4
import matplotlib.pyplot as plt
5
from sklearn.metrics import log_loss, confusion_matrix
6
import tensorflow as tf
7
8
import config
9
from utils import store_to_csv, read_csv
10
11
# Network input dimensions, read once from the project configuration.
n_x = config.IMAGE_PXL_SIZE_X
n_y = config.IMAGE_PXL_SIZE_Y
n_z = config.SLICES
num_channels = config.NUM_CHANNELS

# tf Graph inputs.
# Batched network input, laid out as (batch, slices, x, y, channels).
x = tf.placeholder(
    tf.float32,
    shape=(config.BATCH_SIZE, n_z, n_x, n_y, num_channels),
    name='input')
# One integer class label per batch element.
y = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE,), name='label')
# Dropout keep probability.
keep_prob = tf.placeholder(tf.float32, name='dropout')

# Raw image volume without a channel axis; the leading dimension is declared
# as 1.
# NOTE(review): evaluate_validation_set feeds np.stack(batch) into this
# placeholder, which only fits when the batch holds one image - confirm.
input_img = tf.placeholder(tf.float32, shape=(1, n_z, n_x, n_y))
# Append a single channel axis; -1 keeps the first dimension so that batches
# are supported.
reshape_op = tf.reshape(input_img, shape=(-1, n_z, n_x, n_y, 1))

# Single test image volume, with neither a batch nor a channel axis.
input_test_img = tf.placeholder(tf.float32, shape=(n_z, n_x, n_y))
# Add the batch and channel axes to the test image.
reshape_test_op = tf.reshape(input_test_img, shape=(-1, n_z, n_x, n_y, 1))
34
35
36
def store_error_plots(validation_err, train_err):
    """Save the validation and training error curves as PNG images.

    Each curve is drawn on its own figure and written to
    ``validation_errors.png`` and ``train_errors.png`` respectively.
    Plotting is best-effort: any failure is reported on stdout and is not
    re-raised.

    Args:
        validation_err: sequence of validation error values to plot.
        train_err: sequence of training error values to plot.
    """
    curves = (
        (validation_err, "validation_errors.png"),
        (train_err, "train_errors.png"),
    )
    try:
        for errors, file_name in curves:
            # pyplot keeps drawing on the current figure, so without a fresh
            # figure the training image would also contain the validation
            # curve.
            figure = plt.figure()
            try:
                plt.plot(errors)
                plt.savefig(file_name)
            finally:
                # Release the figure so repeated calls do not pile up
                # open figures.
                plt.close(figure)
    except Exception as e:
        print("Drawing errors failed with: {}".format(e))
45
46
47
def high_error_increase(errors,
                        current,
                        least_count=3,
                        incr_threshold=0.1):
    """Tell whether ``current`` exceeds any recorded error by the threshold.

    Args:
        errors: previously recorded error values.
        current: the newest error value.
        least_count: minimum number of recorded errors required before any
            increase is reported.
        incr_threshold: minimum difference ``current - previous`` that
            counts as a high increase.

    Returns:
        True when at least ``least_count`` errors are recorded and
        ``current`` is at least ``incr_threshold`` above one of them,
        otherwise False.
    """
    if len(errors) >= least_count:
        for previous in errors:
            if current - previous >= incr_threshold:
                return True
    return False
56
57
58
def get_max_prob(output, ind_value):
    """Return the value of ``output`` at ``ind_value``, flipped for no-cancer.

    Args:
        output: indexable collection of per-class values.
        ind_value: index of the selected class.

    Returns:
        ``output[ind_value]``, or ``1.0 - output[ind_value]`` when
        ``ind_value`` equals ``config.NO_CANCER_CLS``.
    """
    selected = output[ind_value]
    # Presumably this turns a "no cancer" probability into the matching
    # cancer probability - confirm against the consumers of the result.
    return 1.0 - selected if ind_value == config.NO_CANCER_CLS else selected
64
65
66
def accuracy(predictions, labels):
    """Return the percentage of rows whose arg-max class equals the label.

    Args:
        predictions: 2-D array with one row of class scores per sample.
        labels: expected class index for every row of ``predictions``.

    Returns:
        The share of correctly classified rows, scaled to 0-100.
    """
    predicted_classes = np.argmax(predictions, 1)
    hits = np.sum(predicted_classes == labels)
    sample_count = predictions.shape[0]
    return 100 * hits / sample_count
69
70
71
def evaluate_log_loss(predictions, target_labels):
    """Compute the log loss of ``predictions`` against ``target_labels``.

    The class set is fixed to ``[0, 1]``, so the call does not depend on
    which classes happen to occur in ``target_labels``.

    Args:
        predictions: predicted per-class probabilities for every sample.
        target_labels: true class label of every sample.

    Returns:
        The value returned by ``sklearn.metrics.log_loss``.
    """
    known_classes = [0, 1]
    return log_loss(target_labels, predictions, labels=known_classes)
73
74
75
def get_confusion_matrix(target_labels, predictions, labels=(0, 1)):
    """Build the confusion matrix for arg-max class predictions.

    Args:
        target_labels: true class label of every sample.
        predictions: 2-D array with one row of class scores per sample; the
            predicted class is the index of the largest score in each row.
        labels: class labels used to index the matrix, in row/column order.

    Returns:
        The matrix produced by ``sklearn.metrics.confusion_matrix``.
    """
    predicted_labels = np.argmax(predictions, 1)
    # ``labels`` is keyword-only in current scikit-learn releases; passing
    # it positionally raises a TypeError there.
    return confusion_matrix(target_labels, predicted_labels, labels=labels)
78
79
80
def display_confusion_matrix_info(target_labels, predictions, labels=[0, 1]):
    """Print the four confusion-matrix counts and return the matrix.

    Args:
        target_labels: true class label of every sample.
        predictions: per-class scores of every sample.
        labels: class labels forwarded to ``get_confusion_matrix``.

    Returns:
        The confusion matrix obtained from ``get_confusion_matrix``.
    """
    matrix = get_confusion_matrix(target_labels, predictions, labels)

    # The first index is the true class and the second the predicted class.
    report = (
        ("True negatives count: ", matrix[0][0]),
        ("False negatives count: ", matrix[1][0]),
        ("True positives count: ", matrix[1][1]),
        ("False positives count: ", matrix[0][1]),
    )
    for caption, count in report:
        print(caption, count)

    return matrix
88
89
def get_sensitivity(confusion_matrix):
    """Return the true-positive rate, TP / (TP + FN).

    Args:
        confusion_matrix: 2x2 matrix indexed as ``[true][predicted]``, with
            the positive class at index 1.

    Returns:
        True positives divided by the number of actual positives.
    """
    positives_row = confusion_matrix[1]
    false_negatives, true_positives = positives_row[0], positives_row[1]

    actual_positives = float(true_positives + false_negatives)
    return true_positives / actual_positives
94
95
96
def get_specificity(confusion_matrix):
    """Return the true-negative rate, TN / (TN + FP).

    Args:
        confusion_matrix: 2x2 matrix indexed as ``[true][predicted]``, with
            the negative class at index 0.

    Returns:
        True negatives divided by the number of actual negatives.
    """
    negatives_row = confusion_matrix[0]
    true_negatives, false_positives = negatives_row[0], negatives_row[1]

    actual_negatives = float(true_negatives + false_positives)
    return true_negatives / actual_negatives
101
102
103
def calculate_conv_output_size(x, y, z, strides, filters, paddings, last_depth):
    """Compute the flattened output size of a stack of 3-D convolutions.

    The spatial sizes are updated layer by layer. A ``'VALID'`` layer yields
    ``ceil((size - filter + 1) / stride)`` per axis; any other padding value
    yields ``ceil(size / stride)``.

    Args:
        x, y, z: input sizes along the three spatial axes.
        strides: per-layer strides, each ordered ``[z, x, y]``.
        filters: per-layer filter sizes, each ordered ``[z, x, y]``; only
            read for layers whose padding is ``'VALID'``.
        paddings: per-layer padding mode.
        last_depth: number of output channels of the last layer.

    Returns:
        ``x * y * z * last_depth`` for the final spatial sizes, as an int.
    """
    # Currently axes are transposed [z, x, y]
    for i, stride in enumerate(strides):
        if paddings[i] == 'VALID':
            f = filters[i]
            # The ``np.float`` alias was removed from NumPy; true division by
            # a builtin float already produces the value to round up.
            x = np.ceil((x - f[1] + 1) / float(stride[1]))
            y = np.ceil((y - f[2] + 1) / float(stride[2]))
            z = np.ceil((z - f[0] + 1) / float(stride[0]))
        else:
            x = np.ceil(float(x) / float(stride[1]))
            y = np.ceil(float(y) / float(stride[2]))
            z = np.ceil(float(z) / float(stride[0]))

    return int(x * y * z * last_depth)
117
118
119
def model_store_path(store_dir, step):
    """Return the checkpoint path for ``step`` inside ``store_dir``.

    Args:
        store_dir: directory that holds the checkpoints.
        step: value embedded in the file name.

    Returns:
        ``store_dir`` joined with ``model_<step>.ckpt``.
    """
    checkpoint_name = 'model_{}.ckpt'.format(step)
    return os.path.join(store_dir, checkpoint_name)
122
123
124
def validate_data_loaded(images_batch, images_labels):
    """Check that both the image batch and its labels are non-empty.

    Args:
        images_batch: loaded images; must support ``len``.
        images_labels: labels of the loaded images; must support ``len``.

    Returns:
        True when both collections contain at least one element. Otherwise
        a hint is printed and False is returned.
    """
    # ``len`` is used instead of plain truthiness so that NumPy arrays can
    # be passed. Both collections are checked; the labels alone do not
    # prove that any image was loaded.
    if not (len(images_batch) and len(images_labels)):
        print("Please check you configurations, unable to laod the images...")
        return False
    return True
129
130
131
def evaluate_validation_set(sess,
                            validation_set,
                            valid_prediction,
                            feed_data_key,
                            batch_size):
    """Run the model over the validation set and compute its metrics.

    Batches are pulled from ``validation_set`` until ``num_samples`` samples
    have been requested. The predictions of all batches are then scored
    together.

    Args:
        sess: session used to run the graph operations.
        validation_set: data source exposing ``num_samples`` and
            ``next_batch(batch_size)``, which returns images and labels.
        valid_prediction: graph output producing per-class predictions.
        feed_data_key: placeholder that receives the reshaped image batch.
        batch_size: number of samples requested per batch.

    Returns:
        A tuple ``(accuracy, log_loss, sensitivity, specificity)``, or
        ``(0, 0, 0, 0)`` as soon as a batch fails to load.
    """
    collected_pred, collected_labels = [], []

    requested = 0
    while requested < validation_set.num_samples:
        batch_images, batch_labels = validation_set.next_batch(batch_size)
        if not validate_data_loaded(batch_images, batch_labels):
            return (0, 0, 0, 0)

        # Add the channel axis that the network input expects.
        network_input = sess.run(
            reshape_op, feed_dict={input_img: np.stack(batch_images)})
        # A keep probability of 1 switches dropout off for evaluation.
        predicted = sess.run(
            valid_prediction,
            feed_dict={feed_data_key: network_input, keep_prob: 1.})

        collected_pred.extend(predicted)
        collected_labels.extend(batch_labels)
        requested += batch_size

    validation_acc = accuracy(np.stack(collected_pred),
                              np.stack(collected_labels))
    validation_log_loss = evaluate_log_loss(collected_pred, collected_labels)

    matrix = display_confusion_matrix_info(collected_labels, collected_pred)

    return (validation_acc,
            validation_log_loss,
            get_sensitivity(matrix),
            get_specificity(matrix))
162
163
164
def evaluate_test_set(sess,
                      test_set,
                      test_prediction,
                      feed_data_key,
                      export_csv=True):
    """Predict every test patient and optionally export the results.

    Patients are pulled from ``test_set`` one at a time until
    ``num_samples`` non-empty images have been processed. For each of them
    the value returned by ``get_max_prob`` for the winning class is
    recorded. Any exception is reported on stdout and not re-raised.

    Args:
        sess: session used to run the graph operations.
        test_set: data source exposing ``num_samples`` and
            ``next_patient()``, which returns a patient id and an image.
        test_prediction: graph output producing per-class predictions.
        feed_data_key: placeholder that receives the reshaped image.
        export_csv: when true, the collected patients and values are passed
            to ``store_to_csv`` with ``config.SOLUTION_FILE_PATH``.
    """
    i = 0
    patients, probs = [], []

    try:
        while i < test_set.num_samples:
            patient, test_img = test_set.next_patient()
            if len(test_img):
                test_img = sess.run(reshape_test_op,
                                    feed_dict={input_test_img: test_img})
                # Only images that could be loaded count as processed.
                i += 1
                patients.append(patient)
                # A keep probability of 1 switches dropout off.
                output = sess.run(
                    test_prediction,
                    feed_dict={feed_data_key: test_img, keep_prob: 1.})
                # Index of the column with the highest probability
                # [first class=no cancer=0, second class=cancer=1].
                # ``output`` is already a NumPy array, so the arg-max is
                # taken with NumPy; building a tf.argmax op here would add a
                # new node to the graph for every patient.
                ind_value = np.argmax(output, 1)
                max_prob = get_max_prob(output[0], ind_value[0])
                probs.append(max_prob)

                print("Output {} for patient with id {}, predicted output {}.".format(
                    max_prob, patient, output[0]))

            else:
                print("Corrupted test image, incorrect shape for patient {}".format(
                    patient))

        if export_csv:
            store_to_csv(patients, probs, config.SOLUTION_FILE_PATH)
    except Exception as e:
        print("Storing results failed with: {} Probably solution file is incomplete.".format(e))
199
200
201
def evaluate_solution(sample_solution, with_merged_report=True):
    """Score a solution file against the real solution and print metrics.

    Both files are loaded with the project's ``read_csv`` helper; the
    results are used as tables indexed by patient, with the value in column
    ``config.COLUMN_NAME``.

    Args:
        sample_solution: path of the CSV file holding the predictions.
        with_merged_report: when true, a CSV combining predictions and
            labels is written to ``report_<basename of sample_solution>``.

    Returns:
        A tuple ``(log_loss, accuracy, sensitivity, specificity)``.
    """
    true_labels = read_csv(config.REAL_SOLUTION_CSV)
    predictions = read_csv(sample_solution)
    patients = true_labels.index.values

    probs, labels, probs_cls = [], [], []
    for patient in patients:
        # ``.at`` is the scalar accessor that replaces the removed
        # ``DataFrame.get_value``.
        prob = predictions.at[patient, config.COLUMN_NAME]
        probs.append(prob)
        # Expand the single probability into [class 0, class 1] scores.
        probs_cls.append([1.0 - prob, prob])
        labels.append(true_labels.at[patient, config.COLUMN_NAME])

    probs_cls = np.array(probs_cls)
    log_loss_err = evaluate_log_loss(probs_cls, labels)
    acc = accuracy(probs_cls, np.array(labels))

    confusion_matrix = display_confusion_matrix_info(labels, probs_cls)
    sensitivity = get_sensitivity(confusion_matrix)
    specificity = get_specificity(confusion_matrix)

    print("Log loss: ", round(log_loss_err, 5))
    print("Accuracy: %.1f%%" % acc)
    print("Sensitivity: ", round(sensitivity, 5))
    print("Specificity: ", round(specificity, 5))

    if with_merged_report:
        df = pd.DataFrame(data={'prediction': probs, 'label': labels},
                          columns=['prediction', 'label'],
                          index=true_labels.index)
        df.to_csv('report_{}'.format(os.path.basename(sample_solution)))

    return (log_loss_err, acc, sensitivity, specificity)