b/predict.py
|
from __future__ import absolute_import, division, print_function

from os import environ, getcwd
from os.path import join

import shutil
import re
import os
import argparse
import keras
import numpy as np
import pandas as pd
import sklearn as skl
import tensorflow as tf
from keras.applications.vgg19 import VGG19
from keras.applications import DenseNet169, InceptionResNetV2, DenseNet201
from keras.applications import NASNetMobile
from keras.layers import Dense, GlobalAveragePooling2D
from keras.metrics import binary_accuracy, binary_crossentropy  # kappa_error is not a stock Keras metric; the kappa loss below stays commented out
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from skimage import color, exposure, transform  # scikit-image, used by preprocess_img below
from custom_layers import *
from mura import Mura

pd.set_option('display.max_rows', 20)
pd.set_option('precision', 4)
np.set_printoptions(precision=4)

environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Shut up tensorflow!
print("tf : {}".format(tf.__version__))
print("keras : {}".format(keras.__version__))
print("numpy : {}".format(np.__version__))
print("pandas : {}".format(pd.__version__))
print("sklearn : {}".format(skl.__version__))

# Hyper-parameters / Globals
BATCH_SIZE = 4  # tweak to your GPU's capacity
IMG_HEIGHT = 420  # InceptionResNetV2 & Xception like 299, ResNet50/VGG/Inception 224, NASNetLarge 331
IMG_WIDTH = IMG_HEIGHT
CHANNELS = 3
DIMS = (IMG_HEIGHT, IMG_WIDTH, CHANNELS)  # blame theano
MODEL_TO_EVAL1 = './models/DenseNet169_420_HUMERUS.hdf5'
MODEL_TO_EVAL2 = './models/DenseNet169_420_HAND.hdf5'
MODEL_TO_EVAL3 = './models/DenseNet169_420_FINGER.hdf5'
MODEL_TO_EVAL4 = './models/DenseNet169_420_FOREARM.hdf5'
MODEL_TO_EVAL5 = './models/DenseNet169_420_ELBOW.hdf5'
MODEL_TO_EVAL6 = './models/DenseNet169_420_SHOULDER.hdf5'
MODEL_TO_EVAL7 = './models/DenseNet169_420_WRIST.hdf5'
MODEL_TO_EVAL8 = './models/DenseNet169_420_NEW_HIST.hdf5'
DATA_DIR = 'data_val/'
EVAL_CSV = 'valid.csv'
EVAL_DIR = 'data/val/'

parser = argparse.ArgumentParser(description='Input Path')
# nargs='?' lets the defaults apply when the positional arguments are omitted
parser.add_argument('input_filename', nargs='?', default='valid_image_paths.csv', type=str)
parser.add_argument('output_path', nargs='?', default='prediction.csv', type=str)
proc_data_dir = join(os.getcwd(), 'data/val/')
proc_train_dir = join(proc_data_dir, 'train')
proc_val_dir = join(proc_data_dir, 'val')
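# Example invocation (file names are just the argparse defaults above; adjust to your layout):
#   python predict.py valid_image_paths.csv prediction.csv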
|
|
class ImageString(object):
    _patient_re = re.compile(r'patient(\d+)')
    _study_re = re.compile(r'study(\d+)')
    _image_re = re.compile(r'image(\d+)')
    _study_type_re = re.compile(r'XR_(\w+)')

    def __init__(self, img_filename):
        self.img_filename = img_filename
        self.patient = self._parse_patient()
        self.study = self._parse_study()
        self.image_num = self._parse_image()
        self.study_type = self._parse_study_type()
        self.image = self._parse_image()
        self.normal = self._parse_normal()
        self.valid = self._parse_valid()

    def flat_file_name(self):
        return "{}_{}_patient{}_study{}_{}_image{}.png".format(self.valid, self.study_type, self.patient, self.study,
                                                               self.normal, self.image)

    def _parse_patient(self):
        return int(self._patient_re.search(self.img_filename).group(1))

    def _parse_study(self):
        return int(self._study_re.search(self.img_filename).group(1))

    def _parse_image(self):
        return int(self._image_re.search(self.img_filename).group(1))

    def _parse_study_type(self):
        return self._study_type_re.search(self.img_filename).group(1)

    def _parse_normal(self):
        return "normal" if ("negative" in self.img_filename) else "abnormal"

    def _parse_normal_label(self):
        return 1 if ("negative" in self.img_filename) else 0

    def _parse_valid(self):
        return "valid" if ("valid" in self.img_filename) else "test"
|
|
def preprocess_img(img):
    # Histogram normalization in v channel
    hsv = color.rgb2hsv(img)
    hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
    img = color.hsv2rgb(hsv)

    # central square crop
    min_side = min(img.shape[:-1])
    centre = img.shape[0] // 2, img.shape[1] // 2
    img = img[centre[0] - min_side // 2:centre[0] + min_side // 2,
              centre[1] - min_side // 2:centre[1] + min_side // 2,
              :]

    # rescale to standard size
    img = transform.resize(img, (IMG_HEIGHT, IMG_WIDTH))

    # roll color axis to axis 0
    img = np.rollaxis(img, -1)

    return img
|
|
def eval(args=None):
    args = parser.parse_args()

    # load up our csv with validation factors
    data_dir = join(getcwd(), DATA_DIR)
    eval_csv = join(data_dir, EVAL_CSV)

    true_labels = []

    ###########################################
    df = pd.read_csv(args.input_filename, names=['img', 'label'], header=None)
    samples = [tuple(x) for x in df.values]
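    # Each row of the input CSV is assumed to be "<image path>,<label>", with paths in the
    # MURA style that the ImageString regexes above expect (e.g. .../XR_WRIST/patientNNNNN/study1_negative/image1.png).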
|
|
    # for img, label in samples:
    #     # assert ("negative" in img) is (label is 0)
    #     enc = ImageString(img)
    #     true_labels.append(enc._parse_normal_label())
    #     cat_dir = join(proc_val_dir, enc.normal)
    #     if not os.path.exists(cat_dir):
    #         os.makedirs(cat_dir)
    #     shutil.copy2(enc.img_filename, join(cat_dir, enc.flat_file_name()))

    ###########################################

    eval_datagen = ImageDataGenerator(rescale=1. / 255
                                      # , histogram_equalization=True
                                      )
    eval_generator = eval_datagen.flow_from_directory(
        EVAL_DIR, class_mode='binary', shuffle=False, target_size=(IMG_HEIGHT, IMG_WIDTH), batch_size=BATCH_SIZE)
    n_samples = eval_generator.samples
    n_steps = int(np.ceil(n_samples / BATCH_SIZE))  # ceil so the last partial batch is included
    base_model = DenseNet169(input_shape=DIMS, weights='imagenet', include_top=False)  # weights='imagenet'
    x = base_model.output
    x = GlobalAveragePooling2D(name='avg_pool')(x)  # comment for RESNET
    # x = WildcatPool2d()(x)

    x = Dense(1, activation='sigmoid', name='predictions')(x)
    model = Model(inputs=base_model.input, outputs=x)
    model.load_weights(MODEL_TO_EVAL8)
    model.compile(optimizer=Adam(lr=1e-3)
                  , loss=binary_crossentropy
                  # , loss=kappa_error
                  , metrics=['binary_accuracy'])
    score, acc = model.evaluate_generator(eval_generator, n_steps)
    print(model.metrics_names)
    print('==> Metrics with eval')
    print("loss :{:0.4f} \t Accuracy:{:0.4f}".format(score, acc))
    y_pred = model.predict_generator(eval_generator, n_steps)

    # print(y_pred)
    # df_filenames = pd.Series(np.array(eval_generator.filenames), name='filenames')
    # df_classes = pd.Series(np.array(y_pred), name='classes')

    # prediction_data = pd.concat([df_filenames, df_classes, ])
    # prediction_data.to_csv(args.output_path + "/prediction.csv")

    # the same single-model prediction is passed for all five prediction slots Mura expects
    mura = Mura(eval_generator.filenames, y_true=eval_generator.classes, y_pred1=y_pred, y_pred2=y_pred,
                y_pred3=y_pred, y_pred4=y_pred, y_pred5=y_pred, output_path=args.output_path)
    print(mura.metrics_by_encounter())


if __name__ == '__main__':
    eval()