# predict.py
from __future__ import absolute_import, division, print_function

from os import environ, getcwd
from os.path import join

import shutil
import re
import os
import argparse
import keras
import numpy as np
import pandas as pd
import sklearn as skl
import tensorflow as tf
from keras.applications.vgg19 import VGG19
from keras.applications import DenseNet169, InceptionResNetV2, DenseNet201
from keras.applications import NASNetMobile
from keras.layers import Dense, GlobalAveragePooling2D
from keras.metrics import binary_accuracy, binary_crossentropy  # kappa_error is not a built-in Keras metric
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from skimage import color, exposure, transform  # needed by preprocess_img
from custom_layers import *
from mura import Mura

pd.set_option('display.max_rows', 20)
pd.set_option('precision', 4)
np.set_printoptions(precision=4)

environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Shut up tensorflow!
print("tf      : {}".format(tf.__version__))
print("keras   : {}".format(keras.__version__))
print("numpy   : {}".format(np.__version__))
print("pandas  : {}".format(pd.__version__))
print("sklearn : {}".format(skl.__version__))

# Hyper-parameters / Globals
BATCH_SIZE = 4  # tweak to your GPU's capacity
IMG_HEIGHT = 420  # InceptionResNetV2 & Xception like 299, ResNet50/VGG/Inception 224, NASNetMobile 331
IMG_WIDTH = IMG_HEIGHT
CHANNELS = 3
DIMS = (IMG_HEIGHT, IMG_WIDTH, CHANNELS)  # blame theano
MODEL_TO_EVAL1 = './models/DenseNet169_420_HUMERUS.hdf5'
MODEL_TO_EVAL2 = './models/DenseNet169_420_HAND.hdf5'
MODEL_TO_EVAL3 = './models/DenseNet169_420_FINGER.hdf5'
MODEL_TO_EVAL4 = './models/DenseNet169_420_FOREARM.hdf5'
MODEL_TO_EVAL5 = './models/DenseNet169_420_ELBOW.hdf5'
MODEL_TO_EVAL6 = './models/DenseNet169_420_SHOULDER.hdf5'
MODEL_TO_EVAL7 = './models/DenseNet169_420_WRIST.hdf5'
MODEL_TO_EVAL8 = './models/DenseNet169_420_NEW_HIST.hdf5'  # only this checkpoint is loaded below
DATA_DIR = 'data_val/'
EVAL_CSV = 'valid.csv'
EVAL_DIR = 'data/val/'

parser = argparse.ArgumentParser(description='Input Path')
# positional arguments only fall back to their defaults when declared with nargs='?'
parser.add_argument('input_filename', nargs='?', default='valid_image_paths.csv', type=str)
parser.add_argument('output_path', nargs='?', default='prediction.csv', type=str)
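# Typical invocation (file names below are just the defaults declared above):
#   python predict.py valid_image_paths.csv prediction.csv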
proc_data_dir = join(os.getcwd(), 'data/val/')
proc_train_dir = join(proc_data_dir, 'train')
proc_val_dir = join(proc_data_dir, 'val')

class ImageString(object):
    """Parses patient, study, image number and study type out of a MURA-style image path."""

    _patient_re = re.compile(r'patient(\d+)')
    _study_re = re.compile(r'study(\d+)')
    _image_re = re.compile(r'image(\d+)')
    _study_type_re = re.compile(r'XR_(\w+)')

    def __init__(self, img_filename):
        self.img_filename = img_filename
        self.patient = self._parse_patient()
        self.study = self._parse_study()
        self.image_num = self._parse_image()
        self.study_type = self._parse_study_type()
        self.image = self._parse_image()
        self.normal = self._parse_normal()
        self.valid = self._parse_valid()

    def flat_file_name(self):
        return "{}_{}_patient{}_study{}_{}_image{}.png".format(self.valid, self.study_type, self.patient,
                                                               self.study, self.normal, self.image)

    def _parse_patient(self):
        return int(self._patient_re.search(self.img_filename).group(1))

    def _parse_study(self):
        return int(self._study_re.search(self.img_filename).group(1))

    def _parse_image(self):
        return int(self._image_re.search(self.img_filename).group(1))

    def _parse_study_type(self):
        return self._study_type_re.search(self.img_filename).group(1)

    def _parse_normal(self):
        return "normal" if ("negative" in self.img_filename) else "abnormal"

    def _parse_normal_label(self):
        return 1 if ("negative" in self.img_filename) else 0

    def _parse_valid(self):
        return "valid" if ("valid" in self.img_filename) else "test"

def preprocess_img(img):
    """Histogram-equalize, centre-crop and resize a single image (not wired into the generator below)."""
    # Histogram normalization in v channel
    hsv = color.rgb2hsv(img)
    hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
    img = color.hsv2rgb(hsv)

    # central square crop
    min_side = min(img.shape[:-1])
    centre = img.shape[0] // 2, img.shape[1] // 2
    img = img[centre[0] - min_side // 2:centre[0] + min_side // 2,
              centre[1] - min_side // 2:centre[1] + min_side // 2,
              :]

    # rescale to standard size
    img = transform.resize(img, (IMG_HEIGHT, IMG_WIDTH))

    # roll color axis to axis 0 (channels-first)
    img = np.rollaxis(img, -1)

    return img


def eval(args=None):
    args = parser.parse_args()

    # load up our csv with validation factors
    data_dir = join(getcwd(), DATA_DIR)
    eval_csv = join(data_dir, EVAL_CSV)

    true_labels = []

    ###########################################
    df = pd.read_csv(args.input_filename, names=['img', 'label'], header=None)
    samples = [tuple(x) for x in df.values]
    # for img, label in samples:
    #     # assert ("negative" in img) is (label == 0)
    #     enc = ImageString(img)
    #     true_labels.append(enc._parse_normal_label())
    #     cat_dir = join(proc_val_dir, enc.normal)
    #     if not os.path.exists(cat_dir):
    #         os.makedirs(cat_dir)
    #     shutil.copy2(enc.img_filename, join(cat_dir, enc.flat_file_name()))
    ###########################################

    eval_datagen = ImageDataGenerator(rescale=1./255
                                      # , histogram_equalization=True
                                      )
    eval_generator = eval_datagen.flow_from_directory(
        EVAL_DIR, class_mode='binary', shuffle=False, target_size=(IMG_HEIGHT, IMG_WIDTH), batch_size=BATCH_SIZE)
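    # flow_from_directory expects one sub-folder per class under EVAL_DIR (e.g. the
    # abnormal/ and normal/ folders the commented-out prep step above would create);
    # class indices are assigned alphanumerically, so 'normal' maps to label 1 here.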
    n_samples = eval_generator.samples
    steps = int(np.ceil(n_samples / BATCH_SIZE))

    base_model = DenseNet169(input_shape=DIMS, weights='imagenet', include_top=False)
    x = base_model.output
    x = GlobalAveragePooling2D(name='avg_pool')(x)  # comment for RESNET
    # x = WildcatPool2d()(x)
    x = Dense(1, activation='sigmoid', name='predictions')(x)
    model = Model(inputs=base_model.input, outputs=x)
    model.load_weights(MODEL_TO_EVAL8)
    model.compile(optimizer=Adam(lr=1e-3)
                  , loss=binary_crossentropy
                  # , loss=kappa_error  (custom loss, not imported here)
                  , metrics=['binary_accuracy'])
    score, acc = model.evaluate_generator(eval_generator, steps)
    print(model.metrics_names)
    print('==> Metrics with eval')
    print("loss: {:0.4f} \t accuracy: {:0.4f}".format(score, acc))
    y_pred = model.predict_generator(eval_generator, steps)
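    # y_pred is an (n_samples, 1) array of sigmoid scores from this single checkpoint;
    # the same array is passed to all five y_predN slots of the Mura helper below.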

    # print(y_pred)
    # df_filenames = pd.Series(np.array(eval_generator.filenames), name='filenames')
    # df_classes   = pd.Series(np.array(y_pred), name='classes')
    # prediction_data = pd.concat([df_filenames, df_classes])
    # prediction_data.to_csv(args.output_path + "/prediction.csv")

    mura = Mura(eval_generator.filenames, y_true=eval_generator.classes, y_pred1=y_pred, y_pred2=y_pred,
                y_pred3=y_pred, y_pred4=y_pred, y_pred5=y_pred, output_path=args.output_path)
    print(mura.metrics_by_encounter())


if __name__ == '__main__':
    eval()