|
a |
|
b/minigpt4/common/gradcam.py |
|
|
1 |
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from scipy.ndimage import gaussian_filter
from skimage import transform as skimage_transform
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def getAttMap(img, attMap, blur=True, overlap=True):
    """Upsample an attention map to image size and optionally overlay it on the image.

    The map is min-max normalized to [0, 1], resized to ``img.shape[:2]`` with
    bicubic interpolation (``order=3``), optionally Gaussian-blurred and
    renormalized, and finally either returned as-is or blended with ``img``
    using a "jet" colormap.

    Args:
        img: Image array; its first two dimensions set the output size. When
            ``overlap`` is True it is blended with an RGB heatmap, so it is
            presumably (H, W, 3) with values in [0, 1] -- TODO confirm with callers.
        attMap: 2-D attention/saliency map at any resolution. It is NOT
            modified; a normalized copy is used internally.
        blur: If True, smooth the resized map with a Gaussian whose sigma is
            2% of the larger image side, then renormalize to [0, 1].
        overlap: If True, return the image blended with the colored heatmap;
            otherwise return the normalized, resized map.

    Returns:
        If ``overlap`` is True, ``(1 - w) * img + w * heatmap_rgb`` with
        ``w = map ** 0.7`` per pixel; otherwise the (H, W) normalized map.
    """
    # Work on a new array so the caller's map is never mutated, and promote
    # integer maps to float so the normalizing division is valid.
    att = np.asarray(attMap)
    if not np.issubdtype(att.dtype, np.floating):
        att = att.astype(np.float64)
    att = att - att.min()
    if att.max() > 0:
        att = att / att.max()

    # resize clips to the input range by default, so cubic overshoot cannot
    # push values outside [0, 1] (which would make ``** 0.7`` produce NaNs).
    att = skimage_transform.resize(att, img.shape[:2], order=3, mode="constant")

    if blur:
        att = gaussian_filter(att, 0.02 * max(img.shape[:2]))
        att = att - att.min()
        # Same guard as above: a constant (e.g. all-zero) map would otherwise
        # divide by zero and turn the whole result into NaNs.
        if att.max() > 0:
            att = att / att.max()

    if not overlap:
        return att

    cmap = plt.get_cmap("jet")
    heatmap_rgb = cmap(att)[..., :3]  # colormap returns RGBA; drop alpha
    # Gamma < 1 boosts mid-level attention so it shows through the blend.
    weight = (att**0.7)[..., np.newaxis]
    return (1 - weight) * img + weight * heatmap_rgb