--- a
+++ b/visualization/utils.py
@@ -0,0 +1,328 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+
+# Force synchronous CUDA kernel launches so device-side errors are reported
+# at the call that caused them (a debugging aid; see PyTorch CUDA notes).
+# setdefault keeps a value the caller already exported. NOTE: this only takes
+# effect if it runs before CUDA is initialised in the process.
+os.environ.setdefault('CUDA_LAUNCH_BLOCKING', '1')
+
+
+class ActivationsAndGradients:
+    """ Class for extracting activations and
+    registering gradients from targeted intermediate layers.
+
+    Forward hooks store each target layer's output and backward hooks store
+    the gradient w.r.t. that output, both as detached CPU tensors (after the
+    optional ``reshape_transform``). Assuming each target layer runs once per
+    forward pass, ``activations[k]`` and ``gradients[k]`` belong to the same
+    layer.
+    """
+
+    def __init__(self, model, target_layers, reshape_transform):
+        self.model = model
+        self.gradients = []
+        self.activations = []
+        self.reshape_transform = reshape_transform
+        self.handles = []
+        for target_layer in target_layers:
+            self.handles.append(
+                target_layer.register_forward_hook(
+                    self.save_activation))
+            # Backward compatibility with older pytorch versions:
+            if hasattr(target_layer, 'register_full_backward_hook'):
+                self.handles.append(
+                    target_layer.register_full_backward_hook(
+                        self.save_gradient))
+            else:
+                self.handles.append(
+                    target_layer.register_backward_hook(
+                        self.save_gradient))
+
+    def save_activation(self, module, input, output):
+        """Forward hook: record a detached CPU copy of the layer output."""
+        activation = output
+        if self.reshape_transform is not None:
+            activation = self.reshape_transform(activation)
+        self.activations.append(activation.cpu().detach())
+
+    def save_gradient(self, module, grad_input, grad_output):
+        """Backward hook: record the gradient w.r.t. the layer output."""
+        # Gradients are computed in reverse order, so prepend to keep this
+        # list in the same layer order as ``self.activations``.
+        grad = grad_output[0]
+        if self.reshape_transform is not None:
+            grad = self.reshape_transform(grad)
+        self.gradients = [grad.cpu().detach()] + self.gradients
+
+    def __call__(self, x):
+        """Clear previously recorded tensors and run the model on ``x``."""
+        self.gradients = []
+        self.activations = []
+        return self.model(x)
+
+    def release(self):
+        """Remove all registered hooks; safe to call more than once."""
+        for handle in self.handles:
+            handle.remove()
+        self.handles = []
+
+
+class GradCAM:
+    """Grad-CAM saliency maps for one or more target layers of ``model``.
+
+    Can be used as a context manager; the hooks are removed on exit.
+    """
+
+    def __init__(self,
+                 model,
+                 target_layers,
+                 reshape_transform=None,
+                 use_cuda=False,
+                 output_index=1):
+        """
+        :param model: The network to explain; it is put in eval mode.
+        :param target_layers: Modules whose activations and gradients are hooked.
+        :param reshape_transform: Optional callable applied to every hooked
+            activation and gradient before it is stored.
+        :param use_cuda: Move the model, and each input, to the GPU.
+        :param output_index: Position of the class scores when the model
+            returns several outputs (a tuple/list, or a tensor with more
+            than two dimensions). A 2-D (batch, classes) tensor is used as-is.
+        """
+        self.model = model.eval()
+        self.target_layers = target_layers
+        self.reshape_transform = reshape_transform
+        self.cuda = use_cuda
+        self.output_index = output_index
+        if self.cuda:
+            self.model = model.cuda()
+        self.activations_and_grads = ActivationsAndGradients(
+            self.model, target_layers, reshape_transform)
+
+    @staticmethod
+    def get_cam_weights(grads):
+        """ Get a vector of weights for every channel in the target layer.
+        Methods that return weights channels,
+        will typically need to only implement this function.
+
+        Grad-CAM uses the gradient averaged over axes 2 and 3 (presumably the
+        spatial H, W axes of an (N, C, H, W) array).
+        """
+        return np.mean(grads, axis=(2, 3), keepdims=True)
+
+    @staticmethod
+    def get_loss(output, target_category):
+        """Sum over the batch of each sample's score for its target class."""
+        loss = 0
+        for i, category in enumerate(target_category):
+            loss = loss + output[i, category]
+        return loss
+
+    def get_cam_image(self, activations, grads):
+        """Channel-weighted sum of ``activations`` (summed over axis 1)."""
+        weights = self.get_cam_weights(grads)
+        weighted_activations = weights * activations
+        cam = weighted_activations.sum(axis=1)
+
+        return cam
+
+    @staticmethod
+    def get_target_width_height(input_tensor):
+        """Return ``(width, height)`` from the last two input dimensions."""
+        width, height = input_tensor.size(-1), input_tensor.size(-2)
+        return width, height
+
+    def compute_cam_per_layer(self, input_tensor):
+        """Compute one scaled CAM per target layer, resized to the input.
+
+        Each list entry has shape (N, 1, H, W) so the layers can be
+        concatenated along axis 1.
+        """
+        activations_list = [a.cpu().data.numpy()
+                            for a in self.activations_and_grads.activations]
+        grads_list = [g.cpu().data.numpy()
+                      for g in self.activations_and_grads.gradients]
+        target_size = self.get_target_width_height(input_tensor)
+
+        cam_per_target_layer = []
+        # Loop over the saliency image from every layer
+        for layer_activations, layer_grads in zip(activations_list, grads_list):
+            cam = self.get_cam_image(layer_activations, layer_grads)
+            # ReLU: keep only positive values, as in standard Grad-CAM.
+            cam[cam < 0] = 0
+            scaled = self.scale_cam_image(cam, target_size)
+            cam_per_target_layer.append(scaled[:, None, :])
+
+        return cam_per_target_layer
+
+    def aggregate_multi_layers(self, cam_per_target_layer):
+        """Average the per-layer CAMs and min-max scale the result."""
+        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
+        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
+        result = np.mean(cam_per_target_layer, axis=1)
+        return self.scale_cam_image(result)
+
+    @staticmethod
+    def scale_cam_image(cam, target_size=None):
+        """Min-max scale each map in ``cam`` to ~[0, 1], optionally resizing.
+
+        :param target_size: ``(width, height)`` as expected by ``cv2.resize``.
+        :returns: float32 array stacking the processed maps.
+        """
+        result = []
+        for img in cam:
+            img = img - np.min(img)
+            img = img / (1e-7 + np.max(img))
+            if target_size is not None:
+                img = cv2.resize(img, target_size)
+            result.append(img)
+        result = np.float32(result)
+
+        return result
+
+    def _select_logits(self, output):
+        """Return the (batch, classes) score tensor from the model output.
+
+        Per the author's (tentative) note, the adapted model returns several
+        outputs with the class scores at index 1 (``self.output_index``).
+        A tensor with at most two dimensions is returned unchanged.
+        """
+        if isinstance(output, torch.Tensor) and output.dim() <= 2:
+            return output
+        return output[self.output_index]
+
+    def __call__(self, input_tensor, target_category=None):
+        """Compute the aggregated Grad-CAM map(s) for ``input_tensor``.
+
+        :param input_tensor: Batched model input; its last two dimensions
+            are the height and width the CAMs are resized to.
+        :param target_category: One class index for the whole batch (int),
+            one index per sample (sequence), or None to use each sample's
+            highest-scoring class.
+        :returns: float32 array with one min-max scaled map per sample.
+        """
+        if self.cuda:
+            input_tensor = input_tensor.cuda()
+
+        # Forward pass; the model returns logits (not passed through softmax).
+        # Select the class-score tensor before branching so an explicit
+        # target_category indexes the same tensor as the arg-max path.
+        output = self._select_logits(self.activations_and_grads(input_tensor))
+        if isinstance(target_category, int):
+            target_category = [target_category] * input_tensor.size(0)
+
+        if target_category is None:
+            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
+            print(f"category id: {target_category}")
+        else:
+            assert (len(target_category) == input_tensor.size(0))
+
+        self.model.zero_grad()
+        loss = self.get_loss(output, target_category)
+        print('the loss is', loss)
+        loss.backward(retain_graph=True)
+
+        # In most of the saliency attribution papers, the saliency is
+        # computed with a single target layer.
+        # Commonly it is the last convolutional layer.
+        # Here we support passing a list with multiple target layers.
+        # It will compute the saliency image for every image,
+        # and then aggregate them (with a default mean aggregation).
+        # This gives you more flexibility in case you just want to
+        # use all conv layers for example, all Batchnorm layers,
+        # or something else.
+        cam_per_layer = self.compute_cam_per_layer(input_tensor)
+        return self.aggregate_multi_layers(cam_per_layer)
+
+    def __del__(self):
+        # __init__ may have raised before the hooks object was created.
+        activations_and_grads = getattr(self, 'activations_and_grads', None)
+        if activations_and_grads is not None:
+            activations_and_grads.release()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.activations_and_grads.release()
+        # An IndexError raised inside the ``with`` body is reported and
+        # suppressed (by returning True); other exceptions propagate.
+        if isinstance(exc_value, IndexError):
+            print(
+                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+            return True
+
+
+def show_cam_on_image(img: np.ndarray,
+                      mask: np.ndarray,
+                      use_rgb: bool = False,
+                      colormap: int = cv2.COLORMAP_JET,
+                      overlay: bool = False) -> np.ndarray:
+    """ Render the cam mask as a heatmap, optionally on top of the image.
+    By default the heatmap is in BGR format.
+
+    :param img: The base image in RGB or BGR format, np.float32 in [0, 1].
+    :param mask: The cam mask, expected in [0, 1].
+    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
+    :param colormap: The OpenCV colormap to be used.
+    :param overlay: If True, add the heatmap to 'img' (the classic CAM
+        overlay). The default, False, returns the heatmap alone.
+    :returns: uint8 image scaled so its brightest value is 255 (all-zero
+        input stays zero).
+    :raises Exception: If 'img' has values greater than 1.
+    """
+    if np.max(img) > 1:
+        raise Exception(
+            "The input image should np.float32 in the range [0, 1]")
+
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+    if use_rgb:
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+    heatmap = np.float32(heatmap) / 255
+
+    cam = heatmap + img if overlay else heatmap
+    peak = np.max(cam)
+    # An all-black result (possible with some colormaps) would otherwise be
+    # divided by zero and turn into NaNs.
+    if peak > 0:
+        cam = cam / peak
+    return np.uint8(255 * cam)
+
+
+def center_crop_img(img: np.ndarray, size: int):
+    """Resize so the shorter side equals ``size``, then centre-crop the
+    longer side to ``size``.
+
+    :param img: Image array laid out as (H, W, C) or (H, W).
+    :param size: Side length of the square result.
+    :returns: The cropped image; ``img`` itself if it is already size x size.
+    """
+    h, w = img.shape[:2]
+
+    if w == h == size:
+        return img
+
+    if w < h:
+        ratio = size / w
+        new_w = size
+        new_h = int(h * ratio)
+    else:
+        ratio = size / h
+        new_h = size
+        new_w = int(w * ratio)
+
+    # cv2.resize expects dsize as (width, height).
+    img = cv2.resize(img, dsize=(new_w, new_h))
+
+    if new_w == size:
+        top = (new_h - size) // 2
+        img = img[top: top + size]
+    else:
+        left = (new_w - size) // 2
+        img = img[:, left: left + size]
+
+    return img
\ No newline at end of file