|
a |
|
b/app/backend/Utils.py |
|
|
1 |
#############################GRAD-Cam############################# |
|
|
2 |
|
|
|
3 |
#@title Grad-CAM code with full-credit to Jimin Tan (https://github.com/tanjimin/grad-cam-pytorch-light) |
|
|
4 |
import torch ,cv2 |
|
|
5 |
from PIL import Image |
|
|
6 |
import numpy as np |
|
|
7 |
import os |
|
|
8 |
class InfoHolder(): |
|
|
9 |
|
|
|
10 |
def __init__(self, heatmap_layer): |
|
|
11 |
self.gradient = None |
|
|
12 |
self.activation = None |
|
|
13 |
self.heatmap_layer = heatmap_layer |
|
|
14 |
|
|
|
15 |
def get_gradient(self, grad): |
|
|
16 |
self.gradient = grad |
|
|
17 |
|
|
|
18 |
def hook(self, model, input, output): |
|
|
19 |
if output.requires_grad: |
|
|
20 |
output.register_hook(self.get_gradient) |
|
|
21 |
self.activation = output.detach() |
|
|
22 |
|
|
|
23 |
def generate_heatmap(weighted_activation): |
|
|
24 |
raw_heatmap = torch.mean(weighted_activation, 0) |
|
|
25 |
heatmap = np.maximum(raw_heatmap.detach().cpu(), 0) |
|
|
26 |
heatmap /= torch.max(heatmap) + 1e-10 |
|
|
27 |
return heatmap.numpy() |
|
|
28 |
|
|
|
29 |
def superimpose(input_img, heatmap): |
|
|
30 |
img = to_RGB(input_img) |
|
|
31 |
heatmap = cv2.resize(heatmap, (img.shape[0], img.shape[1])) |
|
|
32 |
heatmap = np.uint8(255 * heatmap) |
|
|
33 |
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) |
|
|
34 |
superimposed_img = np.uint8(heatmap * 0.6 + img * 0.4) |
|
|
35 |
pil_img = cv2.cvtColor(superimposed_img,cv2.COLOR_BGR2RGB) |
|
|
36 |
return pil_img |
|
|
37 |
|
|
|
38 |
def to_RGB(tensor): |
|
|
39 |
tensor = (tensor - tensor.min()) |
|
|
40 |
tensor = tensor/(tensor.max() + 1e-10) |
|
|
41 |
image_binary = np.transpose(tensor.numpy(), (1, 2, 0)) |
|
|
42 |
image = np.uint8(255 * image_binary) |
|
|
43 |
return image |
|
|
44 |
|
|
|
45 |
def grad_cam(model, input_tensor, heatmap_layer, truelabel=None): |
|
|
46 |
|
|
|
47 |
info = InfoHolder(heatmap_layer) |
|
|
48 |
heatmap_layer.register_forward_hook(info.hook) |
|
|
49 |
|
|
|
50 |
output = model(input_tensor.unsqueeze(0))[0] |
|
|
51 |
truelabel = truelabel if truelabel else torch.argmax(output) |
|
|
52 |
|
|
|
53 |
output[truelabel].backward() |
|
|
54 |
|
|
|
55 |
weights = torch.mean(info.gradient, [0, 2, 3]) |
|
|
56 |
activation = info.activation.squeeze(0) |
|
|
57 |
|
|
|
58 |
weighted_activation = torch.zeros(activation.shape) |
|
|
59 |
for idx, (weight, activation) in enumerate(zip(weights, activation)): |
|
|
60 |
weighted_activation[idx] = weight * activation |
|
|
61 |
heatmap = generate_heatmap(weighted_activation) |
|
|
62 |
return superimpose(input_tensor, heatmap) |
|
|
63 |
|
|
|
64 |
def use_gradcam(img_path,dest_path,model,transforms): |
|
|
65 |
image=Image.open(img_path).convert('RGB') |
|
|
66 |
layer4=model[0][-1] |
|
|
67 |
heatmap_layer=layer4[2].conv2 |
|
|
68 |
input_tensor=transforms(image) |
|
|
69 |
|
|
|
70 |
#get filename without extension |
|
|
71 |
filename=os.path.splitext(os.path.basename(img_path))[0] |
|
|
72 |
grad_cam_image= grad_cam(model, input_tensor, heatmap_layer) |
|
|
73 |
return cv2.imwrite(os.path.join(dest_path,'(gradcam)'+filename+'.png'),grad_cam_image) |