Switch to unified view

a b/keras_CNN/keras_evaluate.py
1
"""
2
Evaluate a Keras model on a fresh dataset.
3
"""
4
5
from __future__ import print_function
6
from keras.models import Model
7
from keras.optimizers import SGD, Adagrad, Adadelta, RMSprop
8
from keras.utils import np_utils
9
from keras.models import load_model
10
from keras.models import model_from_json
11
from keras import backend as K
12
import sys, os, numpy as np
13
from load_tumor_image_data import *
14
15
nb_classes    = 2
16
batch_size    = 64
17
18
def main():
19
    args = get_args()
20
    model_file   = args['model_file']
21
    data_file    = args['data_file']
22
    weights_file = None
23
    split_model  = False
24
    if args['model_weights'] != '':
25
        weights_file = args['model_weights']
26
        split_model = True
27
    elif os.path.splitext(model_file)[1] == '.json':
28
        weights_file = os.path.splitext(model_file)[0] + '.weights.hd5'
29
        split_model = True
30
    window_normalize = args['window_normalize']
31
   
32
    # Give some feedback on settings
33
    if not args['normalize'] and (window_normalize == False):
34
        print("Using raw images.")
35
    if window_normalize:
36
        print("Using window normalization.")
37
    # Load the data file
38
    (X,y) = load_all_data(data_file, normalize=args['normalize'], window_normalize=window_normalize)
39
    # Feedback on the data
40
    print("X shape: {} ; X[0] shape: {}  X[0][2] shape: {}".format(X.shape, X[0].shape, X[0][2].shape))
41
    img_rows, img_cols = X[0][2].shape
42
    print('img_rows: {0}, img_cols: {1}'.format(img_rows, img_cols))
43
    print('X shape:', X.shape)
44
    print('{0} samples ({1} (+), {2} (-))'.format(X.shape[0], y.sum(), len(y)-y.sum()))
45
46
    # convert class vectors to binary class matrices
47
    ground_truth = y
48
    y = np_utils.to_categorical(y, nb_classes)
49
50
    # Check validity of model file:
51
    model = None
52
    if split_model:
53
        print('Loading model design {0}\nand weights {1}'.format(model_file, weights_file))
54
        with open(model_file, 'rU') as json_file:
55
            model = model_from_json(json_file.read())
56
        model.load_weights(weights_file)
57
    else:
58
        print('Loading full model file {0}'.format(model_file))
59
        model = load_model(model_file)
60
    
61
    if args['show_layout']:
62
        print('Network Layout:')
63
        model.summary()
64
        print('\n\n')
65
    
66
    y_pred = model.predict(X, batch_size=batch_size, verbose=1)
67
    result = quantify_results(y_pred[:,1], ground_truth, auc=True)
68
69
    if args['list_cases']:
70
        metadata = load_metadata(data_file)
71
        print('')
72
        print_case_table(y_pred, ground_truth, metadata)
73
        print('\n\n')
74
75
    if args['expression_file'] != None:
76
        print('\nSaving expression vectors in {0}.'.format(args['expression_file']))
77
        import csv
78
        feature_layer = find_feature_layer(model, index=int(args['expression_layer_index']))
79
        print("Getting output at layer index {0}. Layer output shape: {1}".format(feature_layer, model.layers[feature_layer].output_shape))
80
        expression    = get_output_from_layer(model, X, layer_index=feature_layer)
81
        with open(args['expression_file'], 'wb') as csvfile:
82
            writer = csv.writer(csvfile, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
83
            for vec in expression:
84
                writer.writerow(vec)
85
86
    if args['predictions_file'] != None:
87
        # predictions are just expression at the last layer, reduced to a single floating-point value if
88
        # the number of classes is 2.
89
        print('\nSaving prediction values in {0}.'.format(args['predictions_file']))
90
        import csv
91
        with open(args['predictions_file'], 'wb') as csvfile:
92
            writer = csv.writer(csvfile, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
93
            for vec in y_pred:
94
                try:
95
                    if len(vec) == 2:
96
                        vec = [vec[1]]
97
                except:
98
                    vec = [vec]
99
                writer.writerow(vec)        
100
101
    print('(+) - Total: {0} ; True: {1} ; False {2}'.format(result['npos'],result['tp'],result['fp']))
102
    print('(-) - Total: {0} ; True: {1} ; False {2}'.format(result['nneg'],result['tn'],result['fn']))
103
    print('auc: {0} acc: {1}, sensitivity: {2}, specificity: {3}, precision: {4}'.format(result['auc'], result['acc'], result['sens'],result['spec'], result['prec']))
104
    return 0
105
106
def get_output_from_layer(model, X, layer_index=None, train_mode=False, n_categories=2):
107
    """
108
    get the output from a specific layer given the input; defaults to the last
109
    layer before a reduction to classes (<=n_categories)
110
    """
111
    mode = 0 if not train_mode else 1
112
    if layer_index == None:
113
        layer_index = find_feature_layer(model, n_categories=n_categories)
114
     
115
    get_nth_layer_output = K.function([model.layers[0].input, K.learning_phase()],
116
                                      [model.layers[layer_index].output])
117
118
    layer_output = get_nth_layer_output([X, mode])[0]
119
120
    return layer_output
121
122
123
def find_last_feature_layer(model, n_categories=2):
124
    unwanted_layers  = ['Dropout', 'Pooling']
125
    last_layer_index = len(model.layers) - 1
126
    last_layer       = model.layers[last_layer_index]
127
    while last_layer.output_shape[-1] <= n_categories \
128
          or any(ltype in str(type(last_layer)) for ltype in unwanted_layers):
129
          last_layer_index -= 1
130
          last_layer = model.layers[last_layer_index]
131
    return last_layer_index
132
133
def find_feature_layer(model, index=-2):
134
    unwanted_layers     = ['Dropout', 'Pooling']
135
    direction           = np.sign(index)
136
    i_orig              = index
137
    index               = abs(index)
138
    current_layer_index = len(model.layers) - 1
139
    current_layer       = model.layers[current_layer_index]
140
    if direction >= 0:
141
        current_layer       = model.layers[0]
142
        current_layer_index = 0
143
    while (direction != 0 and index >= 0) or any(ltype in str(type(current_layer)) for ltype in unwanted_layers):
144
          current_layer_index += direction
145
          current_layer = model.layers[current_layer_index]
146
          index -= 1
147
    return current_layer_index
148
149
150
def quantify_results(predictions, ground_truth, auc=False):
151
    fp = fn = tp = tn = npos = nneg = 0
152
    total = len(predictions)
153
    if auc:
154
        from sklearn import metrics
155
        auc_value   = round(metrics.roc_auc_score(np.ravel(ground_truth), np.ravel(predictions)), 3)
156
        predictions = np.round(predictions)
157
158
    for idx, predicted in enumerate(predictions):
159
        if is_positive(ground_truth[idx]):
160
            npos += 1
161
        else:
162
            nneg += 1
163
        if is_positive(predicted) == is_positive(ground_truth[idx]):
164
            if is_positive(ground_truth[idx]):
165
                tp += 1
166
            else:
167
                tn += 1
168
        else:
169
            if is_positive(predicted):
170
                fp += 1
171
            else:
172
                fn += 1
173
    epsilon = 1e-20 # used to avoid divide-by-zero in rare cases where all example are seen as a single class
174
    acc  = round((tp + tn) / float(total), 3)
175
    sens = round(tp  / (float(npos)    + epsilon), 3)
176
    spec = round(tn  / (float(nneg)    + epsilon), 3)
177
    ppv  = round(tp  / (float(tp + fp) + epsilon), 3)
178
    npv  = round(tn  / (float(tn + fn) + epsilon), 3)
179
    result = {
180
        'fp': fp,
181
        'tp': tp,
182
        'fn': fn,
183
        'tn': tn,
184
        'total': total,
185
        'tot': total,
186
        'npos': npos,
187
        'nneg': nneg,
188
        'acc': acc,
189
        'sens': sens,
190
        'tpr': sens,
191
        'spec': spec,
192
        'spc': spec,
193
        'tnr': spec,
194
        'ppv': ppv,
195
        'prec': ppv,
196
        'npv': npv
197
    }
198
    if auc:
199
        result['auc'] = auc_value
200
    return result
201
202
def print_case_table(y_pred, ground_truth, metadata={}):
203
    meta_keys = metadata.keys()
204
    headings  = ['true class', 'predicted', 'correct?']
205
    headings.extend(meta_keys)
206
    print('\t'.join(headings))
207
    for idx, prediction in enumerate(y_pred):
208
        results = [str(x) for x in [ground_truth[idx][0], prediction, str(ground_truth[idx][0]==prediction)]]
209
        for key in meta_keys:
210
            try:
211
                results.append(str(metadata[key][idx][0]))
212
            except TypeError:
213
                results.append(str(metadata[key][idx]))
214
        print('\t'.join(results))
215
216
def array_like(x):
217
    import collections
218
    return isinstance(x, collections.Sequence) or isinstance(x, np.ndarray)
219
220
def is_positive(cls):
221
    if array_like(cls) and len(cls) > 1:
222
        return cls[0] == 0
223
    else:
224
        return cls != 0 if not array_like(cls) else cls[0] != 0
225
226
def get_args():
227
    import argparse
228
    # construct the argument parser and parse the arguments
229
    ap = argparse.ArgumentParser(prog='{0}'.format(os.path.basename(sys.argv[0])))
230
    ap.add_argument("--list", dest="list_cases", action='store_true', help="List results case-by-case (verbose output).")
231
    ap.add_argument("--layout", dest="show_layout", action='store_true', help="Show network layout.")
232
    ap.add_argument("model_file", metavar='model-file' , help = "Keras model file (.hd5 or .json)")
233
    ap.add_argument("model_weights", metavar='model-weights', nargs='?', default="", 
234
        help = "Keras weights file (.hd5); optional if model file was .hd5 or if name is same as model file except for extension.")
235
    ap.add_argument("data_file", metavar='data-file' , help = "HDF5 file containing dataset to evaluate.")
236
    ap.add_argument("--raw", dest='normalize', action='store_false', help="Use raw images; no normalization.")
237
    ap.add_argument("--window", dest='window_normalize', action='store_true', help="Perform HU window normalization.")
238
    ap.add_argument("-x", dest = 'expression_file', required = False, 
239
        help = "If given, \"expression levels\" for each item in data file are saved here in CSV-compatible format.")
240
    ap.add_argument("-l", dest = 'expression_layer_index', required = False, default = -2,
241
        help = """Used in conjunction with '-x' to select a specific layer's expression output; positive values 
242
                  start from the beginning, negative values from the end.  Default is last layer before reducing 
243
                  to classes (-2); pooling and dropout layers do not count.""")
244
    ap.add_argument("-p", dest = 'predictions_file', required=False, 
245
        help = "If given, \"predictions\" (floating-point) for each item in data file are saved here in CSV-compatible format.")
246
    args = vars(ap.parse_args())
247
    return args
248
249
if __name__ == '__main__':
250
    # construct the argument parser and parse the arguments
251
    sys.exit(main())