a b/gcam.py
1
import torch
2
import torch.nn as nn
3
from networks.RotCAtt_TransUNet_plusplus_gradcam import RotCAtt_TransUNet_plusplus_GradCam
4
import SimpleITK as sitk
5
import numpy as np
6
import cv2
7
import matplotlib.pyplot as plt
8
9
from networks.dense_feature_extraction import Dense
10
from networks.linear_embedding import LinearEmbedding
11
from networks.transformer import Transformer
12
from networks.rotatory_attention import RotatoryAttention
13
from networks.recon import Reconstruction
14
from networks.uct_decoder import UCTDecoder
15
from networks.config import get_config
16
17
class RotModel(nn.Module):
    """Rebuilt RotCAtt-TransUNet++ in inference mode, instrumented for Grad-CAM.

    The network is reassembled from its sub-modules, and their weights are
    copied from the checkpoint at ``model_path``. Every rebuilt sub-module is
    then put in eval mode. ``get_activations`` registers a backward hook on the
    decoder output, so the gradient of that feature map can be read back via
    ``get_activations_gradient`` after calling ``backward()``.
    """

    # Attribute names of the rebuilt sub-modules. Checkpoint keys are routed to
    # them by their first dotted component (e.g. "decoder.x.weight" -> decoder).
    _SUBMODULE_NAMES = (
        'dense',
        'linear_embedding',
        'transformer',
        'rotatory_attention',
        'reconstruct',
        'decoder',
        'out',
    )

    def __init__(self):
        super().__init__()
        self.model_path = 'outputs/RotCAtt_TransUNet_plusplus/VHSCDD_RotCAtt_TransUNet_plusplus_bs6_ps16_epo600_hw512_ly4/model.pth'
        # NOTE(review): this unpickles a whole nn.Module. From torch 2.6 on,
        # torch.load defaults to weights_only=True and will refuse such a file;
        # pass weights_only=False there (only for trusted checkpoints).
        self.trained_model = torch.load(self.model_path)
        self.config = get_config()
        self.dense = Dense(self.config).cuda()
        self.linear_embedding = LinearEmbedding(self.config).cuda()
        self.transformer = Transformer(self.config).cuda()
        self.rotatory_attention = RotatoryAttention(self.config).cuda()
        self.reconstruct = Reconstruction(self.config).cuda()
        self.decoder = UCTDecoder(self.config).cuda()
        self.out = nn.Conv2d(self.config.df[0], self.config.num_classes, kernel_size=(1, 1), stride=(1, 1)).cuda()

        self._load_trained_weights(self.trained_model.state_dict())

        # Switch every rebuilt sub-module to inference behaviour.
        for name in self._SUBMODULE_NAMES:
            getattr(self, name).eval()

        # Gradients of the decoder output, appended by activations_hook.
        self.gradients = []

    def _load_trained_weights(self, trained_state):
        """Copy checkpoint tensors into the matching rebuilt sub-modules.

        Each key is routed by its first dotted component. Keys whose first
        component is not one of ``_SUBMODULE_NAMES`` are skipped.

        Raises:
            KeyError: a routed key does not exist in the target sub-module.
            ValueError: a routed tensor's shape differs from the target's.
        """
        targets = {name: getattr(self, name).state_dict() for name in self._SUBMODULE_NAMES}
        for key, tensor in trained_state.items():
            module_name, sep, local_key = key.partition('.')
            if not sep or module_name not in targets:
                continue
            target = targets[module_name][local_key]
            if target.shape != tensor.shape:
                raise ValueError(
                    f"Shape mismatch for '{key}': checkpoint {tuple(tensor.shape)} "
                    f"vs model {tuple(target.shape)}")
            # state_dict() tensors are detached views sharing storage with the
            # live parameters/buffers, so this in-place copy updates the module.
            target.copy_(tensor)

    def activations_hook(self, grad):
        """Backward hook: record the gradient flowing into the decoder output."""
        self.gradients.append(grad)

    def get_activations_gradient(self):
        """Return the list of gradients recorded by ``activations_hook``."""
        return self.gradients

    def clear_activations_gradient(self):
        """Forget all recorded gradients."""
        self.gradients.clear()

    def _decode(self, x):
        """Run the shared encoder/attention/decoder pipeline; return the decoder output."""
        x1, x2, x3, x4 = self.dense(x)
        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
        # The transformer also returns three attention-weight tensors, unused here.
        e1, e2, e3, _, _, _ = self.transformer(z1, z2, z3)
        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)
        # Transformer and rotatory-attention branches are fused by addition.
        o1, o2, o3 = self.reconstruct(e1 + r1, e2 + r2, e3 + r3)
        return self.decoder(o1, o2, o3, x4)

    def get_activations(self, x):
        """Return ``(logits, decoder_features)`` and hook the features' gradient.

        After ``backward()`` on anything derived from ``logits``, the gradient
        of ``decoder_features`` is available via ``get_activations_gradient``.
        """
        y = self._decode(x)
        y.register_hook(self.activations_hook)
        return self.out(y), y

    def forward(self, x):
        """Return the output-layer logits for ``x``."""
        return self.out(self._decode(x))
104
    
105
    
106
def seg_gradcam(model, image_path, index_list: list, instance_list: list, colormap=cv2.COLORMAP_JET, img_size=(512, 512), device="cuda"):
    """Show segmentation Grad-CAM heatmaps for selected slices of a volume.

    For each slice ``index_list[i]``, the logits of class ``instance_list[i]``
    are summed over all pixels of that slice. The sum is back-propagated to the
    decoder feature map returned by ``model.get_activations``. Each feature
    channel is weighted by its mean gradient. The channel-mean of the weighted
    map, clipped at zero and scaled to [0, 1], is the heatmap.

    The figure has two rows: row 0 overlays the heatmap on the slice, and
    row 1 shows the heatmap alone.

    Args:
        model: Exposes ``get_activations(x) -> (logits, features)`` plus
            ``get_activations_gradient`` / ``clear_activations_gradient``
            (e.g. ``RotModel``).
        image_path: Volume file readable by SimpleITK.
        index_list: Slice indices along the first axis of the loaded array.
        instance_list: Class (logit channel) to explain for each slice.
        colormap: OpenCV colormap applied to the heatmap.
        img_size: ``(width, height)`` passed to ``cv2.resize`` for display.
        device: Device the input batch is moved to (must match the model).

    Raises:
        ValueError: If the two lists differ in length or are empty.
    """
    if len(index_list) != len(instance_list):
        raise ValueError(
            "Length of list of indices is not equal to length of list of instances")
    if not index_list:
        raise ValueError("index_list must contain at least one slice index")

    name_dict = {
        0: "background",
        1: "left_ventricle",
        2: "right_ventricle",
        3: "left_atrium",
        4: "right_atrium",
        5: "myocardium",
        6: "descending_aorta",
        7: "pulmonary_trunk",
        8: "ascending_aorta",
        9: "vena_cava",
        10: "auricle",
        11: "coronary_artery",
    }

    num_slice = len(index_list)
    volume = sitk.GetArrayFromImage(sitk.ReadImage(image_path, sitk.sitkFloat32))
    # One 2-D slice per requested index, batched as (num_slice, 1, H, W).
    slices = np.stack([volume[index] for index in index_list]).astype(np.float32)
    batch = torch.from_numpy(slices).unsqueeze(1).to(device)

    logits, activations = model.get_activations(batch)
    activations = activations.detach()

    # squeeze=False keeps `axes` 2-D even when only one slice is plotted.
    fig, axes = plt.subplots(nrows=2, ncols=num_slice, figsize=(16, 16), squeeze=False)

    for x, (slice_index, class_index) in enumerate(zip(index_list, instance_list)):
        # retain_graph=True lets the single forward graph serve every slice.
        logits[x, class_index].sum().backward(retain_graph=True)
        gradients = model.get_activations_gradient()[-1]
        model.clear_activations_gradient()

        heatmap = _gradcam_heatmap(activations[x], gradients[x], img_size, colormap)
        overlay = _overlay_heatmap(slices[x], heatmap, img_size)

        top, bottom = axes[0][x], axes[1][x]
        overlay_plot = top.imshow(overlay, cmap='RdBu')
        top.set_xlabel(f"Slice: {slice_index} || Class: {name_dict[class_index]}")
        fig.colorbar(overlay_plot, ax=top, fraction=0.046)

        heatmap_plot = bottom.imshow(heatmap, cmap='RdBu')
        bottom.set_xlabel(f"Heatmap of slice: {slice_index}")
        fig.colorbar(heatmap_plot, ax=bottom, fraction=0.046)

    title_text = ", ".join(str(index) for index in index_list)
    plt.suptitle(f"Model's focus on slices number: {title_text}", fontsize=16)
    plt.subplots_adjust(wspace=0.5)
    plt.show()


def _gradcam_heatmap(activation, gradient, img_size, colormap):
    """Turn one sample's (C, H, W) activation and gradient into an RGB uint8 heatmap."""
    # Grad-CAM channel weights: spatial mean of the gradient per channel.
    weights = gradient.mean(dim=(1, 2))
    cam = (activation * weights[:, None, None]).mean(dim=0)
    # ReLU first, then normalise, so a non-positive max cannot flip signs or
    # divide by zero.
    cam = torch.clamp(cam, min=0)
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    cam = cam.detach().cpu().numpy().astype(np.float32)
    cam = cv2.resize(cam, img_size)
    cam_u8 = (np.clip(cam, 0.0, 1.0) * 255).astype(np.uint8)
    colored = cv2.applyColorMap(cam_u8, colormap)
    # OpenCV colormaps are BGR; matplotlib's imshow expects RGB.
    return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)


def _overlay_heatmap(slice_2d, heatmap_rgb, img_size):
    """Blend an RGB uint8 heatmap onto a grayscale slice; return float RGB in [0, 1]."""
    gray = cv2.resize(np.asarray(slice_2d, dtype=np.float32), img_size)
    # Min-max scale raw intensities so they are comparable with the heatmap.
    lo, hi = float(gray.min()), float(gray.max())
    gray = (gray - lo) / (hi - lo) if hi > lo else np.zeros_like(gray)
    rgb = np.stack((gray,) * 3, axis=-1)
    blended = (heatmap_rgb / 255.0) * 0.6 + rgb
    # imshow clips float RGB to [0, 1]; rescale instead of saturating.
    peak = blended.max()
    return blended / peak if peak > 0 else blended
202
    
203
204
    
205
if __name__ == "__main__":
    # Demo: explain four slices of one test volume, one target class per slice.
    image_path = 'data/VHSCDD_512/test_images/0001.nii.gz'
    model = RotModel()

    # index_list[i] is a slice index; instance_list[i] is the class explained on it.
    index_list = [125, 156, 153, 180]
    instance_list = [3, 5, 11, 6]

    seg_gradcam(
        model=model,
        image_path=image_path,
        index_list=index_list,
        instance_list=instance_list,
    )