Diff of /visualization/utils.py [000000] .. [8bbec7]

import cv2
import os
import numpy as np
import torch

# Make CUDA kernel launches synchronous so CUDA errors surface at the
# offending call (useful when debugging).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(
                    self.save_activation))
            # Backward compatibility with older pytorch versions:
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(
                        self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        # Gradients arrive in reverse order during backprop, so prepend
        # to keep the list aligned with self.activations.
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()

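# A minimal usage sketch for ActivationsAndGradients on its own, assuming a
# torchvision model; `resnet18` and the chosen layer are illustrative, not
# part of this file:
#
#     from torchvision.models import resnet18
#     model = resnet18(pretrained=True).eval()
#     wrapper = ActivationsAndGradients(model, [model.layer4[-1]], None)
#     logits = wrapper(torch.zeros(1, 3, 224, 224))
#     logits[0, logits.argmax()].backward()
#     print(wrapper.activations[0].shape)  # torch.Size([1, 512, 7, 7])
#     wrapper.release()
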
class GradCAM:
    def __init__(self,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    @staticmethod
    def get_cam_weights(grads):
        """ Get a vector of weights for every channel in the target layer.
        Methods that return channel weights will typically
        need to only implement this function. """
        # Grad-CAM: global-average-pool the gradients over the spatial axes.
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target_category):
        # Sum the target-class logit over every image in the batch.
        loss = 0
        for i in range(len(target_category)):
            loss = loss + output[i, target_category[i]]
        return loss

    def get_cam_image(self, activations, grads):
        # Weight each activation channel by its pooled gradient, then sum
        # over channels: (N, C, h, w) -> (N, h, w).
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)

        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for layer_activations, layer_grads in zip(activations_list, grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            # ReLU: clip negative values before the min-max scaling in
            # scale_cam_image, which would otherwise absorb this step.
            cam[cam < 0] = 0
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        # Min-max normalize each map to [0, 1]; optionally resize it to the
        # input resolution.
        result = []
        for img in cam:
            img = img - np.min(img)
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)

        return result

    def __call__(self, input_tensor, target_category=None):

        if self.cuda:
            input_tensor = input_tensor.cuda()

        # Forward pass to get the network output logits (before softmax).
        output = self.activations_and_grads(input_tensor)
        # NOTE (translated from the author's comment): modified here --
        # output[0] appears to be at the input size and output[1] at the
        # output size. Unwrap the tuple once, so both branches below see
        # the logits tensor.
        if isinstance(output, tuple):
            output = output[1]

        if isinstance(target_category, int):
            target_category = [target_category] * input_tensor.size(0)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert len(target_category) == input_tensor.size(0)

        self.model.zero_grad()
        loss = self.get_loss(output, target_category)
        print('the loss is', loss)
        loss.backward(retain_graph=True)
        # loss.backward(torch.ones_like(output), retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer, commonly the last
        # convolutional layer. Here we support passing a list with multiple
        # target layers: the saliency image is computed for every layer and
        # then aggregated (with a default mean aggregation). This gives you
        # more flexibility if you want to use, for example, all conv layers
        # or all batchnorm layers.
        cam_per_layer = self.compute_cam_per_layer(input_tensor)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        self.activations_and_grads.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Report and suppress IndexError so the `with` block exits cleanly.
            print(
                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True

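# A usage sketch for GradCAM, assuming a plain torchvision classifier whose
# forward pass returns a logits tensor; `resnet50`, the target layer, and
# `input_tensor` are illustrative, not part of this file:
#
#     from torchvision.models import resnet50
#     model = resnet50(pretrained=True)
#     input_tensor = torch.zeros(1, 3, 224, 224)  # a preprocessed batch
#     with GradCAM(model, target_layers=[model.layer4[-1]]) as cam:
#         grayscale_cam = cam(input_tensor, target_category=None)[0]
#     # grayscale_cam is a (224, 224) float32 map scaled to [0, 1]
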
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = heatmap  # + img (blending with the base image is disabled here)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

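# A usage sketch for show_cam_on_image, assuming a `grayscale_cam` map from
# GradCAM above; the file names are illustrative only:
#
#     img = np.float32(cv2.imread('example.jpg')) / 255  # BGR in [0, 1]
#     visualization = show_cam_on_image(img, grayscale_cam, use_rgb=False)
#     cv2.imwrite('example_cam.jpg', visualization)
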
def center_crop_img(img: np.ndarray, size: int):
    h, w, c = img.shape

    if w == h == size:
        return img

    # Resize so that the shorter side equals `size`, keeping the aspect ratio.
    if w < h:
        ratio = size / w
        new_w = size
        new_h = int(h * ratio)
    else:
        ratio = size / h
        new_h = size
        new_w = int(w * ratio)

    img = cv2.resize(img, dsize=(new_w, new_h))

    # Center-crop the longer side down to `size`.
    if new_w == size:
        offset = (new_h - size) // 2
        img = img[offset: offset + size]
    else:
        offset = (new_w - size) // 2
        img = img[:, offset: offset + size]

    return img
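

# A minimal runnable check of center_crop_img under assumed values (the
# 480x640 input is illustrative, not part of the original file):
if __name__ == '__main__':
    demo = np.zeros((480, 640, 3), dtype=np.uint8)
    # The shorter side (h=480) scales to 224, w becomes int(640*224/480)=298,
    # and the center 224 columns are kept.
    cropped = center_crop_img(demo, 224)
    print(cropped.shape)  # (224, 224, 3)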