import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import os
class _BaseWrapper(object):
    """Shared plumbing for the visualization wrappers in this file.

    Stores the wrapped model, the device of its first parameter, and the
    hook handles that subclasses register so they can be removed later.
    """

    def __init__(self, model):
        super(_BaseWrapper, self).__init__()
        self.model = model
        # Device of the model's first parameter; one-hot targets are moved to it.
        self.device = next(model.parameters()).device
        self.handlers = []  # hook handles registered on the model

    def _encode_one_hot(self, ids):
        """Return a tensor shaped like the stored logits with 1.0 at ``ids`` along dim 1."""
        target = torch.zeros_like(self.logits).to(self.device)
        target.scatter_(1, ids, 1.0)
        return target

    def forward(self, image):
        """Run the model on ``image`` and return its softmax output sorted per row.

        Side effects: stores ``image.shape[2:]`` as ``self.image_shape``, the
        raw model output as ``self.logits`` and its softmax over dim 1 as
        ``self.probs``.

        Returns:
            The result of ``self.probs.sort(dim=1, descending=True)``, i.e.
            the sorted probabilities together with their original indices.
        """
        self.image_shape = image.shape[2:]
        logits = self.model(image)
        self.logits = logits
        self.probs = F.softmax(logits, dim=1)
        ranked = self.probs.sort(dim=1, descending=True)  # ordered results
        return ranked

    def backward(self, ids):
        """
        Class-specific backpropagation: backpropagate a one-hot gradient
        (selected by ``ids``) from the stored logits. The graph is retained
        so that ``backward`` can be called again for other classes.
        """
        target = self._encode_one_hot(ids)
        self.model.zero_grad()
        self.logits.backward(gradient=target, retain_graph=True)

    def generate(self):
        """Produce the visualization; must be implemented by subclasses."""
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions registered through
        ``self.handlers``.
        """
        for hook_handle in self.handlers:
            hook_handle.remove()
class BackPropagation(_BaseWrapper):
    """Vanilla backpropagation: gradient of the selected class scores w.r.t. the input."""

    def forward(self, image):
        """Enable gradients on ``image`` in place, keep a reference to it and
        run the base-class forward pass on it."""
        image.requires_grad_()
        self.image = image
        return super(BackPropagation, self).forward(image)

    def generate(self):
        """Return a copy of the gradient accumulated on the stored image, then
        zero that gradient in place so the next ``backward`` starts clean."""
        input_grad = self.image.grad
        snapshot = input_grad.clone()
        input_grad.zero_()
        return snapshot
class GradCAM(_BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4

    On construction, forward and backward hooks are registered on the
    selected modules. Each forward pass stores the detached module outputs
    in ``fmap_pool`` and each backward pass stores the detached gradient
    w.r.t. the module output in ``grad_pool``, both keyed by module name.
    """

    def __init__(self, model, candidate_layers=None):
        """
        Args:
            model: the network to wrap.
            candidate_layers: collection of module names (as yielded by
                ``model.named_modules()``) to hook. ``None`` hooks every module.
        """
        super(GradCAM, self).__init__(model)
        self.fmap_pool = {}
        self.grad_pool = {}
        self.candidate_layers = candidate_layers  # list

        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        # If any candidates are not specified, the hook is registered to all the layers.
        # NOTE(review): recent PyTorch documents register_backward_hook as
        # deprecated in favour of register_full_backward_hook; the full variant
        # rejects in-place modification of hooked inputs/outputs, so confirm
        # the model is compatible before switching.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _find(self, pool, target_layer):
        """Return ``pool[target_layer]``.

        Raises:
            ValueError: if nothing was recorded under ``target_layer``.
        """
        try:
            return pool[target_layer]
        except KeyError:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        """Build the class activation map for ``target_layer``.

        Requires a preceding ``forward`` and ``backward`` so that feature
        maps and gradients for the layer have been recorded.

        Returns:
            A tensor of shape (B, 1, H, W), where (H, W) is the
            ``image_shape`` stored by ``forward``, with every sample min-max
            normalised to [0, 1]. A sample whose map is constant (for example
            all zeros after the ReLU) is returned as all zeros.

        Raises:
            ValueError: if ``target_layer`` was not hooked or not reached.
        """
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)

        # Channel weights: spatial average of the gradients.
        weights = F.adaptive_avg_pool2d(grads, 1)

        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        B, C, H, W = gcam.shape
        flat = gcam.view(B, -1)
        flat = flat - flat.min(dim=1, keepdim=True)[0]
        peak = flat.max(dim=1, keepdim=True)[0]
        # A constant map has peak == 0 after the shift; dividing by it would
        # fill the map with NaN. Divide by 1 instead so the map stays zero.
        peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        flat = flat / peak

        return flat.view(B, C, H, W)
import copy
import os.path as osp
import click
import cv2
import matplotlib.cm as cm
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import models, transforms
# If the model includes an LSTM (e.g. for image captioning), uncomment the next line:
# torch.backends.cudnn.enabled = False
def get_device(cuda):
    """Pick the torch device to run on and print which one was chosen.

    Args:
        cuda: truthy to request CUDA; it is only used when
            ``torch.cuda.is_available()`` is also true.

    Returns:
        ``torch.device("cuda")`` or ``torch.device("cpu")``.
    """
    use_cuda = cuda and torch.cuda.is_available()
    if not use_cuda:
        print("Device: CPU")
        return torch.device("cpu")

    chosen = torch.device("cuda")
    index = torch.cuda.current_device()
    print("Device:", torch.cuda.get_device_name(index))
    return chosen
'''
def load_images(image_paths):
images = []
raw_images = []
print("Images:")
for i, image_path in enumerate(image_paths):
print("\t#{}: {}".format(i, image_path))
image, raw_image = preprocess(image_path)
images.append(image)
raw_images.append(raw_image)
return images, raw_images
def preprocess(image_path):
raw_image = cv2.imread(image_path)
raw_image = cv2.resize(raw_image, (224,) * 2)
image = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]),
]
)(raw_image[..., ::-1].copy())
return image, raw_image
'''
def load_images(images):
    """Forward ``images`` to ``preprocess`` and return its result unchanged."""
    prepared = preprocess(images)
    return prepared
def preprocess(images):
    """Derive displayable arrays from an already-normalised image batch.

    Args:
        images: 4-D tensor; the transpose below treats it as
            (batch, channel, height, width). Values are presumably in
            [-1, 1], i.e. normalised with mean 0.5 / std 0.5 as in the
            disabled file-based loader kept in this file (confirm against
            the dataloader in use).

    Returns:
        ``(images, raw_images)``: the input tensor itself, untouched, and a
        NumPy array of shape (batch, height, width, channel) holding
        ``255 * (0.5 * images + 0.5)`` with the channel axis reversed
        (presumably RGB -> BGR for OpenCV; confirm).
    """
    # detach()/cpu() make this work for CUDA tensors and for tensors that
    # require grad; Tensor.numpy() raises on both.
    pixels = 255 * (0.5 * images.detach().cpu() + 0.5)
    raw_images = pixels.numpy().transpose(0, 2, 3, 1)[..., ::-1]
    return images, raw_images
def save_gradient(filename, gradient):
    """Min-max scale a gradient map to [0, 255] and write it as an image.

    Args:
        filename: output path handed to ``cv2.imwrite``.
        gradient: 3-D tensor; it is transposed with ``(1, 2, 0)``, so it is
            presumably (channel, height, width) — confirm against callers.

    The input tensor is left unmodified. A constant map (zero range) is
    written as all zeros instead of dividing by zero.
    """
    # Tensor.numpy() on a CPU tensor shares memory with the tensor, so the
    # in-place arithmetic below must run on a private copy.
    array = gradient.detach().cpu().numpy().transpose(1, 2, 0).copy()
    array -= array.min()
    peak = array.max()
    if peak > 0:
        array /= peak
    array *= 255.0
    cv2.imwrite(filename, np.uint8(array))
def save_gradcam(filename, gcam, raw_image, paper_cmap=False):
    """Write a side-by-side panel: Grad-CAM overlay, each channel, raw image.

    Args:
        filename: output path handed to ``cv2.imwrite``.
        gcam: 2-D tensor of activation values; it is passed to the
            ``jet_r`` colormap, so values are expected in [0, 1].
        raw_image: array of shape (H, W, 3) with values on a 0-255 scale
            (presumably BGR as produced by ``preprocess``; confirm).
        paper_cmap: if true, alpha-blend the colormap over the image using
            the activation as alpha; otherwise average colormap and image.

    The written image concatenates, along the width axis: the overlay,
    channel 0, channel 1 and channel 2 of ``raw_image`` (each replicated to
    three channels so it renders as grey), and ``raw_image`` itself.
    """
    heat = gcam.detach().cpu().numpy()
    cmap = cm.jet_r(heat)[..., :3] * 255.0
    if paper_cmap:
        alpha = heat[..., None]
        overlay = alpha * cmap + (1 - alpha) * raw_image
    else:
        # np.float was only an alias of the builtin float (float64) and has
        # been removed from NumPy; np.float64 is the same dtype.
        overlay = (cmap.astype(np.float64) + raw_image.astype(np.float64)) / 2

    # One grey panel per colour channel. The previous code built the third
    # panel from channel 0 again, so channel 2 was never shown.
    channel_panels = [
        np.stack((raw_image[..., k],) * 3, axis=-1) for k in range(3)
    ]
    stack = np.concatenate([overlay] + channel_panels + [raw_image], axis=1)
    cv2.imwrite(filename, np.uint8(stack))
def save_sensitivity(filename, maps):
    """Render a signed sensitivity map with the ``bwr_r`` colormap and save it.

    Args:
        filename: output path handed to ``cv2.imwrite``.
        maps: tensor of signed values.

    Values are divided by the largest absolute value and shifted so that
    zero maps to 0.5 (the colormap midpoint). The coloured image is resized
    to 224x224 with nearest-neighbour interpolation before writing.
    """
    values = maps.detach().cpu().numpy()
    # Largest magnitude. Same value as max(max of positives, -min of
    # non-positives), but also defined when the map is all positive or all
    # non-positive, where one of those selections is empty.
    scale = np.abs(values).max()
    if scale == 0:
        scale = 1.0  # all-zero map: skip the 0/0 and render the midpoint colour
    values = values / scale * 0.5
    values = values + 0.5
    coloured = cm.bwr_r(values)[..., :3]
    coloured = np.uint8(coloured * 255.0)
    coloured = cv2.resize(coloured, (224, 224), interpolation=cv2.INTER_NEAREST)
    cv2.imwrite(filename, coloured)
def gc(model, dataloader, experiment_dir, classes, device):
    """
    Visualize model responses given multiple images

    Takes the first batch from ``dataloader``, keeps the samples whose label
    equals 1, and writes one Grad-CAM panel per sample (for the top-1
    predicted class, at layer ``layer4``) into ``<experiment_dir>/Grad_CAM``.

    Args:
        model: network to inspect; must expose a module named ``layer4``.
        dataloader: iterable yielding ``(images, labels)`` batches.
        experiment_dir: directory under which ``Grad_CAM`` is created.
        classes: indexable mapping from class id to a printable name.
        device: device the model and images are moved to.

    Side effects:
        * ``<experiment_dir>/Grad_CAM`` is deleted if it exists, then recreated.
        * ``model`` is moved to ``device`` and left in eval mode.
        * All hooks registered here are removed before returning, including
          when an error occurs.

    Raises:
        ValueError: if the model has no module named ``layer4``.
    """
    from shutil import rmtree

    arch = ''
    target_layer = 'layer4'
    topk = 1

    output_dir = os.path.join(experiment_dir, 'Grad_CAM')
    if os.path.exists(output_dir):
        rmtree(output_dir)
    os.makedirs(output_dir)

    model.to(device)
    model.eval()

    # Images: first batch only, restricted to samples labelled 1.
    images, labels = next(iter(dataloader))
    images = images[labels == 1, ...]
    if len(images) == 0:
        print("No images with label 1 in the first batch; nothing to visualize.")
        return
    images, raw_images = load_images(images)
    images = images.to(device)

    # Common usage:
    # 1. Wrap your model with visualization classes defined in grad_cam.py
    # 2. Run forward() with images
    # 3. Run backward() with a list of specific classes
    # 4. Run generate() to export results

    # =========================================================================
    print("Vanilla Backpropagation:")

    bp = BackPropagation(model=model)
    try:
        probs, ids = bp.forward(images)  # sorted
        for i in range(topk):
            bp.backward(ids=ids[:, [i]])
            # The gradient itself is not exported; generate() is still called
            # because it zeroes images.grad, so nothing accumulates into the
            # Grad-CAM pass below, which reuses the same tensor.
            bp.generate()
    finally:
        # Remove all the hook function in the "model"
        bp.remove_hook()

    # =========================================================================
    print("Grad-CAM/Guided Backpropagation/Guided Grad-CAM:")

    # Only the target layer is read back, so only that layer is hooked.
    gcam = GradCAM(model=model, candidate_layers=[target_layer])
    try:
        gcam.forward(images)
        for i in range(topk):
            # Grad-CAM
            gcam.backward(ids=ids[:, [i]])
            regions = gcam.generate(target_layer=target_layer)

            for j, raw_image in enumerate(raw_images):
                label = classes[ids[j, i]]
                print("\t#{}: {} ({:.5f})".format(j, label, probs[j, i]))
                save_gradcam(
                    filename=osp.join(
                        output_dir,
                        "{}-{}-gradcam-{}-{}.png".format(
                            j, arch, target_layer, label
                        ),
                    ),
                    gcam=regions[j, 0],
                    raw_image=raw_image,
                )
    finally:
        # Without this the forward/backward hooks stay attached to the model
        # after the function returns and keep storing tensors on every pass.
        gcam.remove_hook()