[314dda]: / lab_cxr_scripts / lab_cxr / CXRLocalization.py

Download this file

128 lines (99 with data), 5.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import numpy as np
# Display
from PIL import Image
from matplotlib.cm import get_cmap
# Tensorflow Keras imports
from tensorflow import GradientTape, reduce_mean
from tensorflow.keras import preprocessing
from tensorflow.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras import Model, Input
# System library
from lab_cxr_scripts.lab_cxr import CXRPrediction
# Localization Algorithm is ADAPTED and MODIFIED from blow referenced paper
# Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D., 2020.
# Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. Int J Comput Vis 128, 336–359.
# https://doi.org/10.1007/s11263-019-01228-7
def generate_detection_heatmap(
img_array, model, xray_labels, end_conv_layer_name, classifier_layer_names
):
# New model linking the original model's inputs to the last convolutional layer
end_conv_layer = model.get_layer(end_conv_layer_name)
end_conv_layer_model = Model(model.inputs, end_conv_layer.output)
# Second model mapping the activation of the last convolutional layer to the final dense layer
classifier_layer_input = Input(shape=end_conv_layer.output.shape[1:])
x = classifier_layer_input
for layer_name in classifier_layer_names:
x = model.get_layer(layer_name)(x)
temp_classifier_model = Model(classifier_layer_input, x)
heatmaps = [] # Holds the detection_heatmap info for each pathology
# Iterating through all the pathologies to generate detection_heatmap for each
for pathology_index in range(0, len(xray_labels)):
# Gradient calculation of the predicted classes for the provided input relative to the output activations of the
# last convolutional layer
with GradientTape() as tape:
# Generating the activations of the end convolutional layer of the model
end_conv_layer_output = end_conv_layer_model(img_array)
tape.watch(end_conv_layer_output) # Keeps records of the generation of the activations
# Calculating the output for each 'pathology_index'
preds = temp_classifier_model(end_conv_layer_output)
seletcted_class_channel = preds[:, pathology_index]
# Gradient of the selected class_index
class_grads = tape.gradient(seletcted_class_channel, end_conv_layer_output)
# Vectorization of the generated gradients
pooled_class_grads = reduce_mean(class_grads, axis=(0, 1, 2))
# Weighting the calculated channel output values by its contribution towards the considered class
end_conv_layer_output = end_conv_layer_output.numpy()[0]
pooled_class_grads = pooled_class_grads.numpy()
for i in range(pooled_class_grads.shape[-1]):
end_conv_layer_output[:, :, i] *= pooled_class_grads[i]
# Calculating the mean across the activation of channels to generate the final detection_heatmap
detection_heatmap = np.mean(end_conv_layer_output, axis=-1)
# Normalizing the generated heatmap to 0 - 1 to make it possible for colorized illustration
detection_heatmap = np.maximum(detection_heatmap, 0) / np.max(detection_heatmap)
heatmaps.append(detection_heatmap) # appending to the heatmap list
return heatmaps
def create_cxr_localization_heatmap(cxr_hash, model, xray_labels, last_conv_layer_name="concatenate",
classifier_layer_names=None):
if classifier_layer_names is None: # mutable argument
classifier_layer_names = [
"global_average_pooling2d_1",
"dense_2",
]
cxr_img_path = './uploads/' + cxr_hash
# Preparing input cxr image - Resizing and Normalizing
img_array = preprocess_input(CXRPrediction.generate_cxr_img_array(cxr_img_path, size=(320, 320)), mode='tf')
# Generating class activation based heatmap
heatmaps = generate_detection_heatmap(
img_array, model, xray_labels, last_conv_layer_name, classifier_layer_names
)
# Iterating through all the pathologies to generate detection_heatmap for each
for pathology_index in range(0, len(xray_labels)):
heatmap = heatmaps[pathology_index]
cxr_img = preprocessing.image.load_img(cxr_img_path)
cxr_img = preprocessing.image.img_to_array(cxr_img)
# Scaling the heatmap values to be in range 0 - 255
heatmap = np.uint8(255 * heatmap)
# Colorizing the heatmap
colormap = get_cmap("jet") # Using jet colormap
# RGB values of the generated colormap
cm_colors = colormap(np.arange(256))[:, :3]
detection_heatmap = cm_colors[heatmap]
# Creating the image with colorized (RGB) heatmap
detection_heatmap = preprocessing.image.array_to_img(detection_heatmap)
detection_heatmap = detection_heatmap.resize((cxr_img.shape[1], cxr_img.shape[0]), Image.ANTIALIAS)
detection_heatmap = preprocessing.image.img_to_array(detection_heatmap)
# Superimpose the heatmap onto the original unprocessed cxr image
superimposed_img = detection_heatmap * 0.4 + cxr_img
superimposed_img = preprocessing.image.array_to_img(superimposed_img)
org_cxr_image = Image.open(cxr_img_path).convert('L')
org_3d_cxr_image = np.tile(np.array(org_cxr_image)[:, :, None], [1, 1, 3]) # 3-Channel original input cxr image
# Save path for the generated heatmap image
localized_image_name = xray_labels[pathology_index] + '-localizedHeatmap-' + cxr_hash
file_path = './localizations/' + cxr_hash.split('.')[0] + '/'
# Make the directory if not exists
if not os.path.exists(file_path):
os.makedirs(file_path)
Image.fromarray(
np.hstack((org_3d_cxr_image, np.array(superimposed_img)))
).save(file_path + localized_image_name)