a b/minigpt4/common/gradcam.py
1
import numpy as np
2
from matplotlib import pyplot as plt
3
from scipy.ndimage import filters
4
from skimage import transform as skimage_transform
5
6
7
def _min_max_normalize(arr):
    """Return a copy of ``arr`` shifted to min 0 and scaled to max 1.

    A constant input (max == min) is returned shifted to all zeros rather
    than divided by zero, so no NaNs are produced. ``arr`` itself is never
    modified.
    """
    arr = arr - arr.min()
    peak = arr.max()
    return arr / peak if peak > 0 else arr


def getAttMap(img, attMap, blur=True, overlap=True):
    """Resize an attention map to ``img`` and optionally overlay it as a heat map.

    The map is min-max normalized and resized to ``img.shape[:2]`` with cubic
    interpolation (``order=3``, zero padding via ``mode="constant"``). It can
    optionally be Gaussian-blurred and re-normalized. When ``overlap`` is
    true, it is rendered through matplotlib's "jet" colormap and blended onto
    ``img``.

    Args:
        img: Image whose first two dimensions give the output size (H, W).
            When ``overlap`` is true it is combined with an (H, W, 3) RGB
            heat map, so it must broadcast against that shape. Presumably a
            float RGB image in [0, 1]; confirm with callers.
        attMap: 2-D attention / Grad-CAM map. It is not modified.
        blur: If true, smooth the resized map with a Gaussian whose sigma is
            2% of the larger image side, then re-normalize it.
        overlap: If true, return the blended image. Otherwise return the
            processed (H, W) attention map.

    Returns:
        The blended (H, W, 3) image when ``overlap`` is true, otherwise the
        normalized, resized (and optionally blurred) attention map. A constant
        attention map yields an all-zero map (and therefore the unmodified
        ``img`` when blending) instead of NaNs.
    """
    # Imported locally: the ``scipy.ndimage.filters`` namespace is deprecated
    # in favor of the top-level ``scipy.ndimage`` functions.
    from scipy.ndimage import gaussian_filter

    # Out-of-place normalization, so the caller's array is left untouched.
    attMap = _min_max_normalize(attMap)
    attMap = skimage_transform.resize(attMap, img.shape[:2], order=3, mode="constant")

    if blur:
        attMap = gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        # Blurring changes the value range; re-normalize with the same
        # zero-guard so a flat map does not turn into 0/0 = NaN.
        attMap = _min_max_normalize(attMap)

    if not overlap:
        return attMap

    # The colormap returns RGBA; keep only the RGB channels.
    heat_map = plt.get_cmap("jet")(attMap)[..., :3]
    # The gamma of 0.7 boosts mid-range attention so it shows up in the blend.
    # The trailing axis lets the per-pixel weight broadcast over channels.
    weight = (attMap ** 0.7)[..., np.newaxis]
    return (1 - weight) * img + weight * heat_map