Diff of /app/backend/Utils.py [000000] .. [69507b]

Switch to unified view

a b/app/backend/Utils.py
1
#############################GRAD-Cam#############################
2
3
#@title Grad-CAM code with full-credit to Jimin Tan (https://github.com/tanjimin/grad-cam-pytorch-light)
4
import torch ,cv2
5
from PIL import Image
6
import numpy as np
7
import os
8
class InfoHolder():
9
10
    def __init__(self, heatmap_layer):
11
        self.gradient = None
12
        self.activation = None
13
        self.heatmap_layer = heatmap_layer
14
15
    def get_gradient(self, grad):
16
        self.gradient = grad
17
18
    def hook(self, model, input, output):
19
        if output.requires_grad:
20
            output.register_hook(self.get_gradient)
21
        self.activation = output.detach()
22
23
def generate_heatmap(weighted_activation):
24
    raw_heatmap = torch.mean(weighted_activation, 0)
25
    heatmap = np.maximum(raw_heatmap.detach().cpu(), 0)
26
    heatmap /= torch.max(heatmap) + 1e-10
27
    return heatmap.numpy()
28
29
def superimpose(input_img, heatmap):
30
    img = to_RGB(input_img)  
31
    heatmap = cv2.resize(heatmap, (img.shape[0], img.shape[1]))
32
    heatmap = np.uint8(255 * heatmap)
33
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
34
    superimposed_img = np.uint8(heatmap * 0.6 + img * 0.4)
35
    pil_img = cv2.cvtColor(superimposed_img,cv2.COLOR_BGR2RGB)
36
    return pil_img
37
38
def to_RGB(tensor):
39
    tensor = (tensor - tensor.min())
40
    tensor = tensor/(tensor.max() + 1e-10)
41
    image_binary = np.transpose(tensor.numpy(), (1, 2, 0))
42
    image = np.uint8(255 * image_binary)
43
    return image
44
45
def grad_cam(model, input_tensor, heatmap_layer, truelabel=None):
46
    
47
    info = InfoHolder(heatmap_layer)
48
    heatmap_layer.register_forward_hook(info.hook)
49
    
50
    output = model(input_tensor.unsqueeze(0))[0]
51
    truelabel = truelabel if truelabel else torch.argmax(output)
52
53
    output[truelabel].backward()
54
55
    weights = torch.mean(info.gradient, [0, 2, 3])
56
    activation = info.activation.squeeze(0)
57
58
    weighted_activation = torch.zeros(activation.shape)
59
    for idx, (weight, activation) in enumerate(zip(weights, activation)):
60
        weighted_activation[idx] = weight * activation
61
    heatmap = generate_heatmap(weighted_activation)
62
    return superimpose(input_tensor, heatmap)
63
64
def use_gradcam(img_path,dest_path,model,transforms):
65
    image=Image.open(img_path).convert('RGB')
66
    layer4=model[0][-1]
67
    heatmap_layer=layer4[2].conv2
68
    input_tensor=transforms(image)
69
70
    #get filename without extension
71
    filename=os.path.splitext(os.path.basename(img_path))[0]
72
    grad_cam_image= grad_cam(model, input_tensor, heatmap_layer)
73
    return cv2.imwrite(os.path.join(dest_path,'(gradcam)'+filename+'.png'),grad_cam_image)