Diff of /src/grad_cam_test.py [000000] .. [f45789]

Switch to side-by-side view

--- a
+++ b/src/grad_cam_test.py
@@ -0,0 +1,272 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from tqdm import tqdm
+import os
+
+class _BaseWrapper(object):
+    def __init__(self, model):
+        super(_BaseWrapper, self).__init__()
+        self.device = next(model.parameters()).device
+        self.model = model
+        self.handlers = []  # a set of hook function handlers
+
+    def _encode_one_hot(self, ids):
+        one_hot = torch.zeros_like(self.logits).to(self.device)
+        one_hot.scatter_(1, ids, 1.0)
+        return one_hot
+
+    def forward(self, image):
+        self.image_shape = image.shape[2:]
+        self.logits = self.model(image)
+        self.probs = F.log_softmax(self.logits, dim=1)#softmax(self.logits, dim=1)
+        return self.probs.sort(dim=1, descending=True)  # ordered results
+
+    def backward(self, ids):
+        """
+        Class-specific backpropagation
+        """
+        one_hot = self._encode_one_hot(ids)
+        self.model.zero_grad()
+        self.logits.backward(gradient=one_hot, retain_graph=True)
+
+    def generate(self):
+        raise NotImplementedError
+
+    def remove_hook(self):
+        """
+        Remove all the forward/backward hook functions
+        """
+        for handle in self.handlers:
+            handle.remove()
+
+class BackPropagation(_BaseWrapper):
+    def forward(self, image):
+        self.image = image.requires_grad_()
+        return super(BackPropagation, self).forward(self.image)
+
+    def generate(self):
+        gradient = self.image.grad.clone()
+        self.image.grad.zero_()
+        return gradient
+
+class GradCAM(_BaseWrapper):
+    """
+    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
+    https://arxiv.org/pdf/1610.02391.pdf
+    Look at Figure 2 on page 4
+    """
+
+    def __init__(self, model, candidate_layers=None):
+        super(GradCAM, self).__init__(model)
+        self.fmap_pool = {}
+        self.grad_pool = {}
+        self.candidate_layers = candidate_layers  # list
+
+        def save_fmaps(key):
+            def forward_hook(module, input, output):
+                self.fmap_pool[key] = output.detach()
+
+            return forward_hook
+
+        def save_grads(key):
+            def backward_hook(module, grad_in, grad_out):
+                self.grad_pool[key] = grad_out[0].detach()
+
+            return backward_hook
+
+        # If any candidates are not specified, the hook is registered to all the layers.
+        for name, module in self.model.named_modules():
+            if self.candidate_layers is None or name in self.candidate_layers:
+                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
+                self.handlers.append(module.register_backward_hook(save_grads(name)))
+
+    def _find(self, pool, target_layer):
+        if target_layer in pool.keys():
+            return pool[target_layer]
+        else:
+            raise ValueError("Invalid layer name: {}".format(target_layer))
+
+    def generate(self, target_layer):
+        fmaps = self._find(self.fmap_pool, target_layer)
+        grads = self._find(self.grad_pool, target_layer)
+        weights = F.adaptive_avg_pool2d(grads, 1)
+
+        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
+        gcam = F.relu(gcam)
+        gcam = F.interpolate(
+            gcam, self.image_shape, mode="bilinear", align_corners=False
+        )
+
+        B, C, H, W = gcam.shape
+        gcam = gcam.view(B, -1)
+        gcam -= gcam.min(dim=1, keepdim=True)[0]
+        gcam /= gcam.max(dim=1, keepdim=True)[0]
+        gcam = gcam.view(B, C, H, W)
+
+        return gcam
+
+import copy
+import os.path as osp
+
+import click
+import cv2
+import matplotlib.cm as cm
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision import models, transforms
+
+
+def get_device(cuda):
+    cuda = cuda and torch.cuda.is_available()
+    device = torch.device("cuda" if cuda else "cpu")
+    if cuda:
+        current_device = torch.cuda.current_device()
+        print("Device:", torch.cuda.get_device_name(current_device))
+    else:
+        print("Device: CPU")
+    return device
+
+
+def load_images(images):
+    return preprocess(images)
+
+
+def preprocess(images):
+    inverse_norm = 255 * (0.5 * images + 0.5)
+    raw_images = (inverse_norm).numpy().transpose(0, 2, 3, 1)[..., ::-1]
+    return images, raw_images
+
+
+def save_gradient(filename, gradient):
+    gradient = gradient.cpu().numpy().transpose(1, 2, 0)
+    gradient -= gradient.min()
+    gradient /= gradient.max()
+    gradient *= 255.0
+    cv2.imwrite(filename, np.uint8(gradient))
+
+
+def save_gradcam(filename, gcam, raw_image, paper_cmap=False):
+    gcam = gcam.cpu().numpy()
+    cmap = cm.jet_r(gcam)[..., :3] * 255.0
+    if paper_cmap:
+        alpha = gcam[..., None]
+        gcam = alpha * cmap + (1 - alpha) * raw_image
+    else:
+        gcam = (cmap.astype(np.float) + raw_image.astype(np.float)) / 2
+    c0 = raw_image[..., 0]
+    c0 = np.stack((c0, c0, c0), axis=-1)
+    c1 = raw_image[..., 1]
+    c1 = np.stack((c1, c1, c1), axis=-1)
+    c2 = raw_image[..., 0]
+    c2 = np.stack((c2, c2, c2), axis=-1)
+    stack = np.concatenate((gcam, c0, c1, c2, raw_image), axis=1)
+    cv2.imwrite(filename, np.uint8(stack))
+
+def save_sensitivity(filename, maps):
+    maps = maps.cpu().numpy()
+    scale = max(maps[maps > 0].max(), -maps[maps <= 0].min())
+    maps = maps / scale * 0.5
+    maps += 0.5
+    maps = cm.bwr_r(maps)[..., :3]
+    maps = np.uint8(maps * 255.0)
+    maps = cv2.resize(maps, (224, 224), interpolation=cv2.INTER_NEAREST)
+    cv2.imwrite(filename, maps)
+
+
+def gc_test_old(model, dataset, experiment_dir, classes, device):
+    """
+    Visualize model responses given multiple images
+    """
+    target_layer = 'layer4'
+    topk = 1
+    output_dir = experiment_dir
+    from shutil import rmtree
+    if os.path.exists(output_dir): rmtree(output_dir)
+    os.makedirs(output_dir)
+    model.to(device)
+    model.eval()
+
+
+    for idx in range(len(dataset)):
+        image, image_path = dataset[idx]
+        image_name = os.path.split(image_path)[1]
+        images = torch.unsqueeze(image, 0)
+        images, raw_images = load_images(images)
+        images = images.to(device)
+
+        bp = BackPropagation(model=model)
+        probs, ids = bp.forward(images)  # sorted
+        for i in range(topk):
+            bp.backward(ids=ids[:, [i]])
+            gradients = bp.generate()
+        # Remove all the hook function in the "model"
+        bp.remove_hook()
+        # =====================================================================
+        #print("Grad-CAM/Guided Backpropagation/Guided Grad-CAM:")
+
+        gcam = GradCAM(model=model)
+        _ = gcam.forward(images)
+
+        for i in range(topk):
+            # Grad-CAM
+            gcam.backward(ids=ids[:, [i]])
+            regions = gcam.generate(target_layer=target_layer)
+
+            for j in range(len(images)):
+                #print("\t#{}: {} ({:.5f})".format(j, classes[ids[j, i]], probs[j, i]))
+                # Grad-CAM
+                result_path = osp.join(output_dir,
+                                       f'{classes[ids[j, i]]}-{image_name}')
+                save_gradcam(
+                    filename=result_path,
+                    gcam=regions[j, 0],
+                    raw_image=raw_images[j],
+                )
+
+
+def gc_test(model, dataset, results_dir, classes, device):
+    """
+    Visualize model responses given multiple images
+    """
+    target_layer = 'conv1'
+    topk = 1
+    output_dir = results_dir
+    from shutil import rmtree
+    if os.path.exists(output_dir): rmtree(output_dir)
+    os.makedirs(output_dir)
+    model.to(device)
+    model.eval()
+
+    for idx in range(len(dataset)):
+        image, image_path = dataset[idx]
+        image_name = os.path.split(image_path)[1]
+        images = torch.unsqueeze(image, 0)
+        images, raw_images = load_images(images)
+        images = images.to(device)
+
+        logits = model(images)
+        probs = F.softmax(logits, dim=1)
+        IH = classes[probs.argmax().item()]
+
+        # =====================================================================
+        #print("Grad-CAM/Guided Backpropagation/Guided Grad-CAM:")
+
+        gcam = GradCAM(model=model)
+        _ = gcam.forward(images)
+
+        # Grad-CAM
+        gcam.backward(ids=torch.Tensor([[1]]).long().to(device)) # IH class
+        regions = gcam.generate(target_layer=target_layer)
+
+        # Grad-CAM
+        image_name, ext = image_name.split('.')
+        result_name = f'{image_name}-ProbIH:{probs[0,1]:.4f}-{IH}.{ext}'
+        result_path = osp.join(output_dir, result_name)
+        save_gradcam(
+            filename=result_path,
+            gcam=regions[0, 0],
+            raw_image=raw_images[0],
+        )