b/visualization/utils.py

import cv2
import os
import numpy as np
import torch

# Make CUDA kernel launches synchronous so errors surface at the failing call.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(
                    self.save_activation))
            # Backward compatibility with older PyTorch versions:
            # register_full_backward_hook was added in PyTorch 1.8, so fall
            # back to the deprecated register_backward_hook before that.
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(
                        self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        # Backward hooks fire in reverse order of the forward pass, so
        # prepend to keep self.gradients aligned with self.activations.
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()
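
# `reshape_transform` above lets these hooks handle layers whose output is not
# a (B, C, H, W) feature map. As a minimal illustrative sketch (not part of
# the original module), a transform for a ViT-style layer that returns a
# (B, tokens, channels) sequence could look like this; the 14x14 patch grid
# and the leading class token are assumptions about the backbone.
def vit_reshape_transform(tensor, height=14, width=14):
    # Drop the class token and fold the patch tokens back into a spatial grid.
    result = tensor[:, 1:, :].reshape(
        tensor.size(0), height, width, tensor.size(2))
    # Move channels in front of the spatial dims: (B, C, H, W), like a CNN map.
    return result.permute(0, 3, 1, 2)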
class GradCAM:
    def __init__(self,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    """ Get a vector of weights for every channel in the target layer.
        Methods that return per-channel weights typically only need to
        implement this function. """

    @staticmethod
    def get_cam_weights(grads):
        return np.mean(grads, axis=(2, 3), keepdims=True)
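
    # For reference, this global average pooling of the gradients is the
    # Grad-CAM channel weight from Selvaraju et al.:
    #     w_k^c = (1 / Z) * sum_{i,j} (dY^c / dA_{i,j}^k)
    # where Y^c is the score for class c, A^k is the k-th activation map,
    # and Z is the number of spatial positions (i, j).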

    @staticmethod
    def get_loss(output, target_category):
        loss = 0
        for i in range(len(target_category)):
            loss = loss + output[i, target_category[i]]
        return loss
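
    # Summing the target-class logits over the batch is enough to obtain
    # per-sample gradients: sample i's logit depends only on sample i's
    # activations, so backpropagating the sum is equivalent to
    # backpropagating each logit separately.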

    def get_cam_image(self, activations, grads):
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)

        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every target layer.
        for layer_activations, layer_grads in zip(activations_list, grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            # ReLU: zeroing the negatives here takes the place of the min in
            # the min-max scaling performed by scale_cam_image.
            cam[cam < 0] = 0
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        result = []
        for img in cam:
            # Min-max normalize each CAM to [0, 1], then optionally resize it
            # to the input resolution.
            img = img - np.min(img)
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)

        return result

    def __call__(self, input_tensor, target_category=None):
        if self.cuda:
            input_tensor = input_tensor.cuda()

        # Forward pass to get the network output logits (before softmax).
        output = self.activations_and_grads(input_tensor)
        if isinstance(target_category, int):
            target_category = [target_category] * input_tensor.size(0)

        if target_category is None:
            # ! Modified here: output[0] is at the input size and output[1]
            # at the output size? Assumes the model returns a tuple; take the
            # second element before picking the top class.
            output = output[1]
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert len(target_category) == input_tensor.size(0)

        self.model.zero_grad()
        loss = self.get_loss(output, target_category)
        print('the loss is', loss)
        loss.backward(retain_graph=True)
        # loss.backward(torch.ones_like(output), retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer, commonly the last
        # convolutional layer. Here we support passing a list with multiple
        # target layers: the saliency image is computed for every layer and
        # the results are aggregated (with a mean aggregation by default).
        # This gives you more flexibility if you want to use, for example,
        # all conv layers or all batchnorm layers.
        cam_per_layer = self.compute_cam_per_layer(input_tensor)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        self.activations_and_grads.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle (and suppress) IndexError raised inside the with-block.
            print(
                f"An exception occurred in CAM with block: {exc_type}. "
                f"Message: {exc_value}")
            return True
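
# Because GradCAM defines __enter__ and __exit__, it can be used as a context
# manager so the hooks are always released when visualization is done, e.g.:
#
#     with GradCAM(model=model, target_layers=target_layers) as cam:
#         grayscale_cam = cam(input_tensor, target_category=281)
#
# (281 is a hypothetical ImageNet class index, used purely for illustration.)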
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    # Note: blending with the base image is disabled here; only the
    # colorized heatmap itself is returned.
    cam = heatmap  # + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)
def center_crop_img(img: np.ndarray, size: int):
    h, w, c = img.shape

    if w == h == size:
        return img

    # Scale the shorter side up or down to `size`, keeping the aspect ratio.
    if w < h:
        ratio = size / w
        new_w = size
        new_h = int(h * ratio)
    else:
        ratio = size / h
        new_h = size
        new_w = int(w * ratio)

    img = cv2.resize(img, dsize=(new_w, new_h))

    # Crop the longer side to a centered `size` x `size` square.
    if new_w == size:
        h = (new_h - size) // 2
        img = img[h: h + size]
    else:
        w = (new_w - size) // 2
        img = img[:, w: w + size]

    return img
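
# A minimal end-to-end sketch of how these utilities fit together. The model,
# image path, and target layer below are illustrative assumptions, not part
# of the original module; target_layers may also hold several layers, whose
# CAMs are then mean-aggregated.
if __name__ == '__main__':
    from torchvision import models, transforms

    model = models.resnet50(pretrained=True)
    target_layers = [model.layer4[-1]]  # last conv block of ResNet-50

    # Load a (hypothetical) image, center-crop it, and scale it to [0, 1].
    img = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)
    img = center_crop_img(img, 224)
    rgb_img = np.float32(img) / 255

    input_tensor = transforms.ToTensor()(rgb_img)
    input_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])(input_tensor)
    input_tensor = input_tensor.unsqueeze(0)

    # Pass an explicit class index (281 = ImageNet "tabby cat"), since the
    # target_category=None branch above assumes a tuple-returning model.
    # The context manager releases the hooks afterwards.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor, target_category=281)

    visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
    cv2.imwrite('cam.jpg', cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR))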