b/Interpretability/heatmap_IG_utils.py
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import sys


# ## Using IG

# ========== Functions =============
def interpolate_images(baseline,
                       image,
                       alphas):
    alphas_x = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]
    baseline_x = tf.expand_dims(baseline, axis=0)
    input_x = tf.expand_dims(image, axis=0)
    delta = input_x - baseline_x
    images = baseline_x + alphas_x * delta
    return images
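
# Usage note (illustrative sketch added for clarity; nothing in this module calls it):
# interpolate_images broadcasts x(alpha) = baseline + alpha * (image - baseline)
# over all alphas at once, so for 150x150x3 inputs and alphas of shape (m+1,)
# it returns a batch of shape (m+1, 150, 150, 3), e.g.
#   interpolate_images(baseline=tf.zeros((150, 150, 3)),
#                      image=tf.random.uniform((150, 150, 3)),
#                      alphas=tf.linspace(0.0, 1.0, 6)).shape  # (6, 150, 150, 3)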


def compute_gradients(model, images, target_class_idx):
    with tf.GradientTape() as tape:
        tape.watch(images)
        logits = model(images)
        # logits has shape (num_interpolated_images, nb_classes)
        # print("logits = model(images): ", logits.shape)
        # probs has shape (num_interpolated_images,)
        probs = logits[:, target_class_idx]
        # print("probs.shape: ", probs.shape)
    return tape.gradient(probs, images)


def integral_approximation(gradients):
    # riemann_trapezoidal
    grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
    integrated_gradients = tf.math.reduce_mean(grads, axis=0)
    return integrated_gradients
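
# Note (added for clarity): this is the trapezoidal Riemann approximation of the
# path integral of the gradients,
#   mean_k[(grad_k + grad_{k+1}) / 2]  ~=  integral_0^1 dF(x' + a*(x - x')) / dx  da,
# which integrated_gradients() below scales elementwise by (image - baseline).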


@tf.function
def integrated_gradients(model,
                         baseline,
                         image,
                         target_class_idx,
                         m_steps=50,
                         batch_size=32):
    # 1. Generate alphas.
    alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)

    # Initialize TensorArray outside loop to collect gradients.
    gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)

    # Iterate over the alphas range in batches for speed, memory efficiency,
    # and scaling to larger m_steps.
    for alpha in tf.range(0, len(alphas), batch_size):
        from_ = alpha
        to = tf.minimum(from_ + batch_size, len(alphas))
        alpha_batch = alphas[from_:to]

        # 2. Generate interpolated inputs between baseline and input.
        interpolated_path_input_batch = interpolate_images(baseline=baseline,
                                                           image=image,
                                                           alphas=alpha_batch)

        # 3. Compute gradients of the target class output w.r.t. the interpolated inputs.
        gradient_batch = compute_gradients(model=model,
                                           images=interpolated_path_input_batch,
                                           target_class_idx=target_class_idx)

        # Write batch indices and gradients to extend TensorArray.
        gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch)

    # Stack path gradients together row-wise into single tensor.
    total_gradients = gradient_batches.stack()

    # 4. Integral approximation through averaging gradients.
    avg_gradients = integral_approximation(gradients=total_gradients)

    # 5. Scale integrated gradients with respect to the input.
    integrated_gradients = (image - baseline) * avg_gradients

    return integrated_gradients
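
# Reference formula (standard Integrated Gradients definition that the function
# above approximates with m_steps interpolation points):
#   IG_i(x) = (x_i - x'_i) * integral_0^1 dF(x' + a * (x - x')) / dx_i  da
# where x is the input image, x' the baseline, and F(.) the model's output for
# target_class_idx. convergence_check() below uses the completeness property
# sum_i IG_i ~= F(x) - F(x') to sanity-check the approximation.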


def convergence_check(model, attributions, baseline, input, target_class_idx):
    """
    Args:
        model(keras.Model): A trained model to generate predictions and inspect.
        attributions(Tensor): Integrated gradients attributions with the same
            shape as the input tensor.
        baseline(Tensor): A 3D image tensor with the shape
            (image_height, image_width, 3), the same shape as the input tensor.
        input(Tensor): A 3D image tensor with the shape
            (image_height, image_width, 3).
        target_class_idx(Tensor): An integer index of the target class in the
            model's output predictions tensor.
    Returns:
        (none): Prints scores and convergence delta to sys.stdout.
    """
    # The model's prediction on the baseline tensor. Ideally, the baseline score
    # should be close to zero.
    baseline_prediction = model(tf.expand_dims(baseline, 0))
    # print("baseline_prediction: ", baseline_prediction)
    # baseline_prediction: tf.Tensor([[2.1683295e-04 3.1699744e-04 4.6704659e-01 5.3241956e-01]], shape=(1, 4), dtype=float32)

    baseline_score = baseline_prediction[0][target_class_idx]
    # print("baseline_score: ", baseline_score)

    # The model's prediction and score on the input tensor.
    input_prediction = model(tf.expand_dims(input, 0))
    # print("input_prediction: ", input_prediction)
    # input_prediction: tf.Tensor([[7.4290162e-01 2.5709778e-01 6.0866233e-07 5.7874078e-10]], shape=(1, 4), dtype=float32)

    input_score = input_prediction[0][target_class_idx]
    # print("input_score: ", input_score)

    # Sum of the IG attributions; by completeness it should approximate
    # input_score - baseline_score.
    # print("\tattributions: ", attributions)
    ig_score = tf.math.reduce_sum(attributions)
    delta = ig_score - (input_score - baseline_score)
    # print("delta: ", delta)
    try:
        # Check that the IG score is within 5% (relative) of the input-minus-baseline score.
        tf.debugging.assert_near(ig_score, (input_score - baseline_score), rtol=0.05)
        tf.print('Approximation accuracy within 5%.', output_stream=sys.stdout)
    except tf.errors.InvalidArgumentError:
        tf.print('Increase or decrease m_steps to increase approximation accuracy.', output_stream=sys.stdout)

    tf.print('Baseline score: {:.3f}'.format(float(baseline_score)))
    tf.print('Input score: {:.3f}'.format(float(input_score)))
    tf.print('IG score: {:.3f}'.format(float(ig_score)))
    tf.print('Convergence delta: {:.3f}'.format(float(delta)))


def plot_img_attributions(model,
                          baseline,
                          image,
                          target_class_idx,
                          m_steps=50,
                          cmap=None,
                          overlay_alpha=0.4,
                          top_prob=0.0,
                          top_label="",
                          meta=None):
    # print("\n@@@@@ plot_img_attributions called @@@@@\n")
    if meta is None:
        meta = {}

    attributions = integrated_gradients(model=model,
                                        baseline=baseline,
                                        image=image,
                                        target_class_idx=target_class_idx,
                                        m_steps=m_steps)
    # print("\n\n\tAttributions: ", attributions)

    convergence_check(model=model,
                      attributions=attributions,
                      baseline=baseline,
                      input=image,
                      target_class_idx=target_class_idx)

    # Sum the attributions across color channels for visualization.
    # The attribution mask is a grayscale image with the same height and width
    # as the original image.
    attribution_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)

    fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(9, 4))

    file_name = meta["file_name"]
    v = meta["v"]
    position = ""
    mode = meta["mode"]

    if mode == "Sag" and v == 1:  # Sag mode has a single crop (v == 1)
        position = f'P{meta["position_index"]}'
    elif mode == "Axial":
        if v == 1:
            position = "Right"
        elif v == 3:
            position = "Left"
            # Flip the v == 3 crops back to their original orientation.
            attribution_mask = np.fliplr(attribution_mask)
            image = np.fliplr(image)
        elif v == 2:
            position = "Center"

    # axs[0, 0].set_title('Baseline image')
    # axs[0, 0].imshow(baseline)
    # axs[0, 0].axis('off')

    axs[0, 0].set_title('Original image')
    axs[0, 0].imshow(image)
    axs[0, 0].axis('off')

    axs[0, 1].set_title('Attribution mask')
    axs[0, 1].imshow(attribution_mask, cmap=cmap)
    axs[0, 1].axis('off')

    axs[0, 2].set_title('Overlay')
    axs[0, 2].imshow(attribution_mask, cmap=cmap)
    axs[0, 2].imshow(image, alpha=overlay_alpha)
    axs[0, 2].axis('off')

    # Figure title, also used as the saved image file name.
    save_name = f'{file_name}-{mode}-{position}-{top_label}-{top_prob:0.1%}'
    fig.suptitle(save_name, fontweight='bold')
    plt.tight_layout()
    # plt.show()  # would block the process
    plt.savefig(f'{meta["save_dir"]}/{save_name}.jpeg')
    # Close the figure so it is not displayed.
    plt.close(fig)
    return fig


def main_ig(model, img_tensor, target_class_idx, prediction, meta):
    """
    Args:
        model: the center, sag, or lateral model
        img_tensor: tensor of the image to attribute with IG
        target_class_idx: index of the top predicted label
        prediction: array of class probabilities
        meta: dict with keys
            file_name,
            v,
            mode,
            position_index (Sag mode only),
            save_dir
    """
    # print("\n\n======== main_ig called ============")
    # print("target_class_idx: ", target_class_idx)
    top_prob = np.max(prediction[0])
    grading = np.array(['normal', 'mild', 'moderate', 'severe'])
    top_label = grading[target_class_idx]
    # print("img_tensor: ", img_tensor.shape, img_tensor.dtype, img_tensor[0][0])

    # ============ Constants ===================
    baseline = tf.zeros(shape=(150, 150, 3))

    # Optionally visualize gradient saturation along the interpolation path.
    visualize_grad_saturation = False
    if visualize_grad_saturation:
        m_steps = 50
        # Generate m_steps intervals for integral_approximation() below.
        alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)

        interpolated_images = interpolate_images(
            baseline=baseline,
            image=img_tensor,
            alphas=alphas)

        # ### Compute gradients along the path.
        path_gradients = compute_gradients(
            model=model,
            images=interpolated_images,
            target_class_idx=target_class_idx)
        # print("path_gradients: ", path_gradients.shape)
        # print(np.max(path_gradients), np.min(path_gradients))

        # Visualize the gradient saturation.
        pred = model(interpolated_images)
        pred_proba = pred[:, target_class_idx]

        plt.figure(figsize=(10, 4))
        ax1 = plt.subplot(1, 2, 1)
        ax1.plot(alphas, pred_proba)
        ax1.set_title('Target class predicted probability over alpha')
        ax1.set_ylabel('model p(target class)')
        ax1.set_xlabel('alpha')
        ax1.set_ylim([0, 1])

        ax2 = plt.subplot(1, 2, 2)
        # Average across interpolation steps.
        average_grads = tf.reduce_mean(path_gradients, axis=[1, 2, 3])
        # Normalize gradients to the [0, 1] range, i.e. (x - min(x)) / (max(x) - min(x)).
        average_grads_norm = ((average_grads - tf.math.reduce_min(average_grads))
                              / (tf.math.reduce_max(average_grads) - tf.math.reduce_min(average_grads)))
        ax2.plot(alphas, average_grads_norm)
        ax2.set_title('Average pixel gradients (normalized) over alpha')
        ax2.set_ylabel('Average pixel gradients')
        ax2.set_xlabel('alpha')
        ax2.set_ylim([0, 1])
        plt.show()

    # =========== main program ================
    # ## Visualize Attributions
    _ = plot_img_attributions(model=model,
                              image=img_tensor,
                              baseline=baseline,
                              target_class_idx=target_class_idx,
                              m_steps=240,
                              cmap=plt.cm.inferno,
                              overlay_alpha=0.4,
                              top_prob=top_prob,
                              top_label=top_label,
                              meta=meta)
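

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (added; not part of the original pipeline). It runs
# main_ig end to end with a throwaway 4-class Keras model, a random 150x150x3
# image, and a hypothetical meta dict, writing the heatmap JPEG to the current
# directory. The real caller supplies a trained model, a preprocessed crop, and
# the study-specific meta fields.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    demo_inputs = tf.keras.Input(shape=(150, 150, 3))
    demo_outputs = tf.keras.layers.Dense(4, activation="softmax")(
        tf.keras.layers.GlobalAveragePooling2D()(demo_inputs))
    demo_model = tf.keras.Model(demo_inputs, demo_outputs)

    demo_image = tf.random.uniform((150, 150, 3))
    demo_prediction = demo_model(tf.expand_dims(demo_image, 0)).numpy()
    demo_target_idx = int(np.argmax(demo_prediction[0]))

    demo_meta = {"file_name": "demo",
                 "v": 2,
                 "mode": "Axial",
                 "position_index": 0,
                 "save_dir": "."}

    main_ig(model=demo_model,
            img_tensor=demo_image,
            target_class_idx=demo_target_idx,
            prediction=demo_prediction,
            meta=demo_meta)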