
Interpretability/heatmap_IG_utils.py
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import sys


# ## Using IG

# ========== Functions =============
def interpolate_images(baseline,
                       image,
                       alphas):
  # Linearly interpolate between the baseline and the input image at each alpha,
  # producing one image per interpolation step along the straight-line path.
  alphas_x = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]
  baseline_x = tf.expand_dims(baseline, axis=0)
  input_x = tf.expand_dims(image, axis=0)
  delta = input_x - baseline_x
  images = baseline_x + alphas_x * delta
  return images

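
# Shape sketch (an illustrative helper, hypothetical and not called anywhere in
# this module): for a (150, 150, 3) image and m_steps + 1 alphas, the
# interpolated path has shape (m_steps + 1, 150, 150, 3).
def _example_interpolation_shapes(m_steps=5):
  baseline = tf.zeros(shape=(150, 150, 3))
  image = tf.ones(shape=(150, 150, 3))
  alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps + 1)
  path = interpolate_images(baseline=baseline, image=image, alphas=alphas)
  return path.shape  # expected: (m_steps + 1, 150, 150, 3)
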
def compute_gradients(model, images, target_class_idx):
  # Gradients of the target-class score with respect to each interpolated image.
  with tf.GradientTape() as tape:
    tape.watch(images)
    logits = model(images)
    # logits has shape (num_interpolated_images, nb_classes)
    # print("logits = model(images): ", logits.shape)
    # probs has shape (num_interpolated_images,)
    probs = logits[:, target_class_idx]
    # print("probs.shape: ", probs.shape)
  return tape.gradient(probs, images)

def integral_approximation(gradients):
  # Riemann trapezoidal rule: average adjacent path gradients, then average
  # over the interpolation steps.
  grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
  integrated_gradients = tf.math.reduce_mean(grads, axis=0)
  return integrated_gradients

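
# Minimal sanity check of the trapezoidal rule above (an illustrative helper,
# not called anywhere in this module): for gradients [0., 1., 2.] the pairwise
# averages are [0.5, 1.5], and their mean is 1.0.
def _integral_approximation_example():
  grads = tf.constant([0.0, 1.0, 2.0])
  return integral_approximation(grads)  # expected: tf.Tensor(1.0)
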
@tf.function
def integrated_gradients(model,
                         baseline,
                         image,
                         target_class_idx,
                         m_steps=50,
                         batch_size=32):
  # 1. Generate alphas.
  alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)

  # Initialize TensorArray outside the loop to collect gradients.
  gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)

  # Iterate over the alphas range and batch the computation for speed,
  # memory efficiency, and scaling to larger m_steps.
  for alpha in tf.range(0, len(alphas), batch_size):
    from_ = alpha
    to = tf.minimum(from_ + batch_size, len(alphas))
    alpha_batch = alphas[from_:to]

    # 2. Generate interpolated inputs between baseline and input.
    interpolated_path_input_batch = interpolate_images(baseline=baseline,
                                                       image=image,
                                                       alphas=alpha_batch)

    # 3. Compute gradients of the target-class output with respect to the
    #    interpolated inputs.
    gradient_batch = compute_gradients(model=model,
                                       images=interpolated_path_input_batch,
                                       target_class_idx=target_class_idx)

    # Write batch indices and gradients to extend the TensorArray.
    gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch)

  # Stack path gradients together row-wise into a single tensor.
  total_gradients = gradient_batches.stack()

  # 4. Integral approximation through averaging gradients.
  avg_gradients = integral_approximation(gradients=total_gradients)

  # 5. Scale integrated gradients with respect to the input.
  integrated_gradients = (image - baseline) * avg_gradients

  return integrated_gradients

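# The steps above approximate the integrated gradients formula
#   IG_i(x) ~= (x_i - x'_i) * (1/m) * sum_{k=1..m} dF(x' + (k/m)(x - x')) / dx_i
# where x is the input image, x' the baseline, and F the target-class score.
# A call sketch for reference (hypothetical tensors; the real call is made in
# plot_img_attributions below):
#   attributions = integrated_gradients(model=model,
#                                       baseline=tf.zeros((150, 150, 3)),
#                                       image=img_tensor,
#                                       target_class_idx=2,
#                                       m_steps=240)
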
def convergence_check(model, attributions, baseline, input, target_class_idx):
  """
  Args:
    model(keras.Model): A trained model to generate predictions and inspect.
    attributions(Tensor): Integrated gradients attributions with the same
      shape as the input tensor.
    baseline(Tensor): A 3D image tensor with the shape
      (image_height, image_width, 3), the same shape as the input tensor.
    input(Tensor): A 3D image tensor with the shape
      (image_height, image_width, 3).
    target_class_idx(Tensor): An integer that corresponds to the target
      class index in the model's output predictions tensor.
  Returns:
    (none): Prints scores and convergence delta to sys.stdout.
  """
  # The model's prediction on the baseline tensor. Ideally, the baseline score
  # should be close to zero.
  baseline_prediction = model(tf.expand_dims(baseline, 0))
  # print("baseline_prediction: ", baseline_prediction)
  # baseline_prediction:  tf.Tensor([[2.1683295e-04 3.1699744e-04 4.6704659e-01 5.3241956e-01]], shape=(1, 4), dtype=float32)

  baseline_score = baseline_prediction[0][target_class_idx]
  # print("baseline_score: ", baseline_score)

  # The model's prediction and score on the input tensor.
  input_prediction = model(tf.expand_dims(input, 0))
  # print("input_prediction: ", input_prediction)
  # input_prediction:  tf.Tensor([[7.4290162e-01 2.5709778e-01 6.0866233e-07 5.7874078e-10]], shape=(1, 4), dtype=float32)

  input_score = input_prediction[0][target_class_idx]
  # print("input_score: ", input_score)

  # Sum of the IG attributions.
  # print("\tattributions: ", attributions)
  ig_score = tf.math.reduce_sum(attributions)
  delta = ig_score - (input_score - baseline_score)
  # print("delta: ", delta)
  try:
    # Check that the IG score is within 5% of the input score minus the baseline score.
    tf.debugging.assert_near(ig_score, (input_score - baseline_score), rtol=0.05)
    tf.print('Approximation accuracy within 5%.', output_stream=sys.stdout)
  except tf.errors.InvalidArgumentError:
    tf.print('Increase or decrease m_steps to increase approximation accuracy.', output_stream=sys.stdout)

  tf.print('Baseline score: {:.3f}'.format(baseline_score))
  tf.print('Input score: {:.3f}'.format(input_score))
  tf.print('IG score: {:.3f}'.format(ig_score))
  tf.print('Convergence delta: {:.3f}'.format(delta))

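# Note on the check above: integrated gradients satisfy the completeness
# property sum_i IG_i(x) ~= F(x) - F(x'), so the attributions should sum to
# the difference between the target-class score at the input and at the
# baseline; `delta` is the remaining approximation error, which shrinks as
# m_steps grows.
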
def plot_img_attributions(model,
                          baseline,
                          image,
                          target_class_idx,
                          m_steps=50,
                          cmap=None,
                          overlay_alpha=0.4,
                          top_prob=0.0,
                          top_label="",
                          meta={}):
  # print("\n@@@@@ plot_img_attributions called @@@@@\n")

  attributions = integrated_gradients(model=model,
                                      baseline=baseline,
                                      image=image,
                                      target_class_idx=target_class_idx,
                                      m_steps=m_steps)
  # print("\n\n\tAttributions: ", attributions)

  convergence_check(model=model,
                    attributions=attributions,
                    baseline=baseline,
                    input=image,
                    target_class_idx=target_class_idx)

  # Sum the attributions across color channels for visualization.
  # The attribution mask is a grayscale image with the same height and width
  # as the original image.
  attribution_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)

  fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(9, 4))

  file_name = meta["file_name"]
  v = meta["v"]
  position = ""
  mode = meta["mode"]

  if mode == "Sag" and v == 1:  # sagittal crops only use label v == 1
    position = f'P{meta["position_index"]}'
  elif mode == "Axial":
    if v == 1:
      position = "Right"
    elif v == 3:
      position = "Left"
      # Flip the v == 3 crops back to their original orientation.
      attribution_mask = np.fliplr(attribution_mask)
      image = np.fliplr(image)
    elif v == 2:
      position = "Center"

  # axs[0, 0].set_title('Baseline image')
  # axs[0, 0].imshow(baseline)
  # axs[0, 0].axis('off')

  axs[0, 0].set_title('Original image')
  axs[0, 0].imshow(image)
  axs[0, 0].axis('off')

  axs[0, 1].set_title('Attribution mask')
  axs[0, 1].imshow(attribution_mask, cmap=cmap)
  axs[0, 1].axis('off')

  axs[0, 2].set_title('Overlay')
  axs[0, 2].imshow(attribution_mask, cmap=cmap)
  axs[0, 2].imshow(image, alpha=overlay_alpha)
  axs[0, 2].axis('off')

  # Figure title and output file name.
  save_name = f'{file_name}-{mode}-{position}-{top_label}-{top_prob:0.1%}'
  fig.suptitle(save_name, fontweight='bold')
  plt.tight_layout()
  # plt.show()  # uncomment to display; this blocks the process
  plt.savefig(f'{meta["save_dir"]}/{save_name}.jpeg')
  # Close the figure so it is not displayed.
  plt.close(fig)
  return fig


def main_ig(model, img_tensor, target_class_idx, prediction, meta):
    """
    Args:
        model: center, sag, or lateral model
        img_tensor: tensor of the image to explain with IG
        target_class_idx: index of the top predicted label
        prediction: array of class confidences (probabilities)
        meta: dict with keys
          file_name,
          v,
          mode,
          position_index (Sag mode only),
          save_dir
    """
    # print("\n\n======== main_ig called ============")
    # print("target_class_idx: ", target_class_idx)
    top_prob = np.max(prediction[0])
    grading = np.array(['normal', 'mild', 'moderate', 'severe'])
    top_label = grading[target_class_idx]
    # print("img_tensor: ", img_tensor.shape, img_tensor.dtype, img_tensor[0][0])
    # ============ Constants ===================
    baseline = tf.zeros(shape=(150, 150, 3))

    # Optionally visualize gradient saturation along the interpolation path.
    visualize_grad_saturation = False
    if visualize_grad_saturation:
        m_steps = 50
        # Generate m_steps+1 interpolation points for integral_approximation() below.
        alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)

        interpolated_images = interpolate_images(
            baseline=baseline,
            image=img_tensor,
            alphas=alphas)

        # ### Compute gradients along the path.
        path_gradients = compute_gradients(
            model=model,
            images=interpolated_images,
            target_class_idx=target_class_idx)
        # print("path_gradients: ", path_gradients.shape)
        # print(np.max(path_gradients), np.min(path_gradients))

        # Visualize the gradient saturation.
        pred = model(interpolated_images)
        pred_proba = pred[:, target_class_idx]

        plt.figure(figsize=(10, 4))
        ax1 = plt.subplot(1, 2, 1)
        ax1.plot(alphas, pred_proba)
        ax1.set_title('Target class predicted probability over alpha')
        ax1.set_ylabel('model p(target class)')
        ax1.set_xlabel('alpha')
        ax1.set_ylim([0, 1])

        ax2 = plt.subplot(1, 2, 2)
        # Average the gradients over pixels and channels, leaving one value per
        # interpolation step.
        average_grads = tf.reduce_mean(path_gradients, axis=[1, 2, 3])
        # Normalize gradients to the [0, 1] range, i.e. (x - min(x)) / (max(x) - min(x)).
        average_grads_norm = ((average_grads - tf.math.reduce_min(average_grads)) /
                              (tf.math.reduce_max(average_grads) - tf.math.reduce_min(average_grads)))
        ax2.plot(alphas, average_grads_norm)
        ax2.set_title('Average pixel gradients (normalized) over alpha')
        ax2.set_ylabel('Average pixel gradients')
        ax2.set_xlabel('alpha')
        ax2.set_ylim([0, 1])
        plt.show()

    # =========== main program ================
    # ## Visualize Attributions
    _ = plot_img_attributions(model=model,
                              image=img_tensor,
                              baseline=baseline,
                              target_class_idx=target_class_idx,
                              m_steps=240,
                              cmap=plt.cm.inferno,
                              overlay_alpha=0.4,
                              top_prob=top_prob,
                              top_label=top_label,
                              meta=meta)
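

# Hypothetical smoke test (a sketch, not part of the production pipeline): it
# wires main_ig up with a throwaway 4-class model, a random image, and the
# current directory as the save location, just to show the expected call shape.
if __name__ == "__main__":
  dummy_model = tf.keras.Sequential([
      tf.keras.Input(shape=(150, 150, 3)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(4, activation="softmax"),
  ])
  dummy_image = tf.random.uniform(shape=(150, 150, 3))
  dummy_prediction = dummy_model(tf.expand_dims(dummy_image, 0)).numpy()
  dummy_meta = {"file_name": "example", "v": 2, "mode": "Axial", "save_dir": "."}
  main_ig(model=dummy_model,
          img_tensor=dummy_image,
          target_class_idx=int(np.argmax(dummy_prediction[0])),
          prediction=dummy_prediction,
          meta=dummy_meta)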