--- a
+++ b/Interpretability/heatmap_IG_utils.py
@@ -0,0 +1,285 @@
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import numpy as np
+import sys
+
+
+# ## Using IG
+
+# ========== Functions =============
+def interpolate_images(baseline,
+                       image,
+                       alphas):
+  alphas_x = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]
+  baseline_x = tf.expand_dims(baseline, axis=0)
+  input_x = tf.expand_dims(image, axis=0)
+  delta = input_x - baseline_x
+  images = baseline_x +  alphas_x * delta
+  return images
+
+
+def compute_gradients(model, images, target_class_idx):
+  with tf.GradientTape() as tape:
+    tape.watch(images)
+    logits = model(images)
+    # logits is of shape (m_steps+1, nb_classes) 
+    # print("logits = model(images): ", logits.shape)
+    # probs output should be of shape (m_steps+1, )
+    probs = logits[:, target_class_idx]
+    # print("probs.shape: ", probs.shape)
+  return tape.gradient(probs, images)
+
+
+def integral_approximation(gradients):
+  # riemann_trapezoidal
+  grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
+  integrated_gradients = tf.math.reduce_mean(grads, axis=0)
+  return integrated_gradients
+
+
+@tf.function
+def integrated_gradients(model,
+                         baseline,
+                         image,
+                         target_class_idx,
+                         m_steps=50,
+                         batch_size=32):
+  """Approximate the integrated-gradients attributions of `image`.
+
+  Gradients of the target-class output are evaluated at `m_steps + 1` points
+  evenly spaced on the straight line from `baseline` to `image`, averaged
+  with `integral_approximation`, and scaled by `image - baseline`.
+
+  Args:
+    model: Callable forwarded to `compute_gradients`.
+    baseline: Start point of the interpolation path; must broadcast against
+      `image`.
+    image: The input to explain.
+    target_class_idx: Index of the model output column to attribute.
+    m_steps: Number of interpolation intervals (`m_steps + 1` alphas are
+      generated).
+    batch_size: Number of interpolated images evaluated per loop iteration.
+  Returns:
+    `(image - baseline) * averaged_gradients`.
+  """
+  # 1. Generate alphas.
+  alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)
+
+  # Initialize TensorArray outside loop to collect gradients.
+  # NOTE: the element dtype is fixed to float32, so the gradients written
+  # below must be float32 as well.
+  gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)
+
+  # Iterate alphas range and batch computation for speed, memory efficiency, and scaling to larger m_steps.
+  # `alpha` here is the start *index* of the batch, not an alpha value.
+  for alpha in tf.range(0, len(alphas), batch_size):
+    from_ = alpha
+    # Clamp the end index so the last batch may be shorter than batch_size.
+    to = tf.minimum(from_ + batch_size, len(alphas))
+    alpha_batch = alphas[from_:to]
+
+    # 2. Generate interpolated inputs between baseline and input.
+    interpolated_path_input_batch = interpolate_images(baseline=baseline,
+                                                       image=image,
+                                                       alphas=alpha_batch)
+
+    # 3. Compute gradients between model outputs and interpolated inputs.
+    gradient_batch = compute_gradients(model=model, images=interpolated_path_input_batch,
+                                       target_class_idx=target_class_idx)
+
+    # Write batch indices and gradients to extend TensorArray.
+    gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch)
+
+  # Stack path gradients together row-wise into single tensor.
+  total_gradients = gradient_batches.stack()
+
+  # 4. Integral approximation through averaging gradients.
+  avg_gradients = integral_approximation(gradients=total_gradients)
+
+  # 5. Scale integrated gradients with respect to input.
+  # NOTE: this local shadows the function's own name inside the body.
+  integrated_gradients = (image - baseline) * avg_gradients
+
+  return integrated_gradients
+
+def convergence_check(model, attributions, baseline, input, target_class_idx):
+  """
+  Args:
+    model(keras.Model): A trained model to generate predictions and inspect.
+    baseline(Tensor): A 3D image tensor with the shape 
+      (image_height, image_width, 3) with the same shape as the input tensor.
+    input(Tensor): A 3D image tensor with the shape 
+      (image_height, image_width, 3).
+    target_class_idx(Tensor): An integer that corresponds to the correct 
+      ImageNet class index in the model's output predictions tensor. Default 
+        value is 50 steps.   
+  Returns:
+    (none): Prints scores and convergence delta to sys.stdout.
+  """
+  # Your model's prediction on the baseline tensor. Ideally, the baseline score
+  # should be close to zero.
+  baseline_prediction = model(tf.expand_dims(baseline, 0))
+  # print("baseline_prediction: ", baseline_prediction)
+  # baseline_prediction:  tf.Tensor([[2.1683295e-04 3.1699744e-04 4.6704659e-01 5.3241956e-01]], shape=(1, 4), dtype=float32)
+
+  baseline_score = baseline_prediction[0][target_class_idx]
+  # print("baseline_score: ", baseline_score)
+
+  # Your model's prediction and score on the input tensor.
+  input_prediction = model(tf.expand_dims(input, 0))
+  # print("input_prediction: ", input_prediction)
+  # input_prediction:  tf.Tensor([[7.4290162e-01 2.5709778e-01 6.0866233e-07 5.7874078e-10]], shape=(1, 4), dtype=float32)
+
+  input_score = input_prediction[0][target_class_idx]
+  # print("input_score: ", input_score)
+
+  # Sum of your IG prediction attributions.
+  # print("\tattributios: ", attributions)
+  ig_score = tf.math.reduce_sum(attributions)
+  delta = ig_score - (input_score - baseline_score)
+  # print("delta: ", delta)
+  try:
+    # Test your IG score is <= 5% of the input minus baseline score.
+    tf.debugging.assert_near(ig_score, (input_score - baseline_score), rtol=0.05)
+    tf.print('Approximation accuracy within 5%.', output_stream=sys.stdout)
+  except tf.errors.InvalidArgumentError:
+    tf.print('Increase or decrease m_steps to increase approximation accuracy.', output_stream=sys.stdout)
+  
+  tf.print('Baseline score: {:.3f}'.format(baseline_score))
+  tf.print('Input score: {:.3f}'.format(input_score))
+  tf.print('IG score: {:.3f}'.format(ig_score))     
+  tf.print('Convergence delta: {:.3f}'.format(delta))
+
+def plot_img_attributions(model,
+                          baseline,
+                          image,
+                          target_class_idx,
+                          m_steps=50,
+                          cmap=None,
+                          overlay_alpha=0.4,
+                          top_prob=0.0,
+                          top_label="",
+                          meta={}):
+  # print("\n@@@@@ plot_img_attributions called @@@@@\n")
+
+  attributions = integrated_gradients(model=model,
+                                      baseline=baseline,
+                                      image=image,
+                                      target_class_idx=target_class_idx,
+                                      m_steps=m_steps)
+  # print("\n\n\tAttributions: ", attributions)
+
+  convergence_check(model=model,
+                    attributions=attributions,
+                    baseline=baseline,
+                    input=image,
+                    target_class_idx=target_class_idx)
+
+  # Sum of the attributions across color channels for visualization.
+  # The attribution mask shape is a grayscale image with height and width
+  # equal to the original image.
+  attribution_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)
+
+  fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(9,4))
+
+  file_name = meta["file_name"]
+  v = meta["v"]
+  position = ""
+  mode = meta["mode"]
+
+  if mode == "Sag" and v == 1:  # sag only 1 label == 1
+    position = f'P{meta["position_index"]}'
+  elif mode == "Axial":
+    if v == 1:
+      position = "Right"
+    elif v == 3:
+      position = "Left"
+      # flip back the v=3 crops
+      attribution_mask = np.fliplr(attribution_mask)
+      image = np.fliplr(image)
+    elif v == 2:
+      position = "Center"
+
+  # axs[0, 0].set_title('Baseline image')
+  # axs[0, 0].imshow(baseline)
+  # axs[0, 0].axis('off')
+
+  axs[0, 0].set_title('Original image')
+  axs[0, 0].imshow(image)
+  axs[0, 0].axis('off')
+
+  axs[0, 1].set_title('Attribution mask')
+  axs[0, 1].imshow(attribution_mask, cmap=cmap)
+  axs[0, 1].axis('off')
+
+  axs[0, 2].set_title('Overlay')
+  axs[0, 2].imshow(attribution_mask, cmap=cmap)
+  axs[0, 2].imshow(image, alpha=overlay_alpha)
+  axs[0, 2].axis('off')
+
+  # title and png file save name
+  save_name = f'{file_name}-{mode}-{position}-{top_label}-{top_prob:0.1%}'
+  fig.suptitle(save_name, fontweight='bold')
+  plt.tight_layout()
+  # plt.show()  # this is needed to block the process
+  plt.savefig(f'{meta["save_dir"]}/{save_name}.jpeg')
+  # close figure by plt.close(fig), it won't be displayed
+  plt.close(fig)
+  return fig
+
+
+def main_ig(model, img_tensor, target_class_idx, prediction, meta):
+    """
+    input:
+        model: center, sag, or lateral model
+        img_tensor: tensor of the image for IG
+        target_class_idx: index of the top pred label
+        prediction: array of confidence in percentage
+        meta: dict of
+          file_name,
+          v,
+          mode,
+
+    """
+    # print("\n\n======== main_ig called ============")
+    # print("target_class_idx: ", target_class_idx)
+    top_prob = np.max(prediction[0])
+    grading = np.array(['normal', 'mild', 'moderate', 'severe'])
+    top_label = grading[target_class_idx]
+    # print("img_tensor: ", img_tensor.shape, img_tensor.dtype, img_tensor[0][0])
+    # ============ Constants ===================
+    baseline = tf.zeros(shape=(150,150,3))
+
+    # if needs to Visualizing gradient saturation
+    visualize_grad_saturation = False
+    if visualize_grad_saturation:
+        m_steps = 50
+        alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1) # Generate m_steps intervals for integral_approximation() below.
+
+        interpolated_images = interpolate_images(
+            baseline=baseline,
+            image=img_tensor,
+            alphas=alphas)
+
+        # ### Compute Gradients
+
+        path_gradients = compute_gradients(
+            model=model,
+            images=interpolated_images,
+            target_class_idx=target_class_idx)
+        # print("path_gradients: ", path_gradients.shape)
+        # print(np.max(path_gradients), np.min(path_gradients))
+        
+        # Visualize the gradient saturation
+        pred = model(interpolated_images)
+        pred_proba = pred[:, target_class_idx]
+
+        plt.figure(figsize=(10, 4))
+        ax1 = plt.subplot(1, 2, 1)
+        ax1.plot(alphas, pred_proba)
+        ax1.set_title('Target class predicted probability over alpha')
+        ax1.set_ylabel('model p(target class)')
+        ax1.set_xlabel('alpha')
+        ax1.set_ylim([0, 1])
+
+        ax2 = plt.subplot(1, 2, 2)
+        # Average across interpolation steps
+        average_grads = tf.reduce_mean(path_gradients, axis=[1, 2, 3])
+        # Normalize gradients to 0 to 1 scale. E.g. (x - min(x))/(max(x)-min(x))
+        average_grads_norm = (average_grads-tf.math.reduce_min(average_grads))/(tf.math.reduce_max(average_grads)-tf.reduce_min(average_grads))
+        ax2.plot(alphas, average_grads_norm)
+        ax2.set_title('Average pixel gradients (normalized) over alpha')
+        ax2.set_ylabel('Average pixel gradients')
+        ax2.set_xlabel('alpha')
+        ax2.set_ylim([0, 1]);
+        plt.show()
+
+    # =========== main program ================
+    # ## Visualize Attributions
+    _ = plot_img_attributions(model=model,
+                            image=img_tensor,
+                            baseline=baseline,
+                            target_class_idx=target_class_idx,
+                            m_steps=240,
+                            cmap=plt.cm.inferno,
+                            overlay_alpha=0.4,
+                            top_prob=top_prob,
+                            top_label=top_label,
+                            meta=meta)