# gcam.py
import torch
import torch.nn as nn
from networks.RotCAtt_TransUNet_plusplus_gradcam import RotCAtt_TransUNet_plusplus_GradCam
import SimpleITK as sitk
import numpy as np
import cv2
import matplotlib.pyplot as plt

from networks.dense_feature_extraction import Dense
from networks.linear_embedding import LinearEmbedding
from networks.transformer import Transformer
from networks.rotatory_attention import RotatoryAttention
from networks.recon import Reconstruction
from networks.uct_decoder import UCTDecoder
from networks.config import get_config

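# RotModel rebuilds the trained RotCAtt_TransUNet_plusplus from its individual
# sub-modules and copies in the checkpoint weights, so that a backward hook can
# be attached to the decoder output, which is what Grad-CAM needs.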
class RotModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model_path = 'outputs/RotCAtt_TransUNet_plusplus/VHSCDD_RotCAtt_TransUNet_plusplus_bs6_ps16_epo600_hw512_ly4/model.pth'
        self.trained_model = torch.load(self.model_path)
        self.config = get_config()

        # Rebuild every sub-module of the trained network on the GPU.
        self.dense = Dense(self.config).cuda()
        self.linear_embedding = LinearEmbedding(self.config).cuda()
        self.transformer = Transformer(self.config).cuda()
        self.rotatory_attention = RotatoryAttention(self.config).cuda()
        self.reconstruct = Reconstruction(self.config).cuda()
        self.decoder = UCTDecoder(self.config).cuda()
        self.out = nn.Conv2d(self.config.df[0], self.config.num_classes, kernel_size=(1, 1), stride=(1, 1)).cuda()

        # Copy the trained weights into each sub-module, stripping the
        # "<module>." prefix from the checkpoint's parameter names.
        dense_state_dict = self.dense.state_dict()
        embedding_state_dict = self.linear_embedding.state_dict()
        transformer_state_dict = self.transformer.state_dict()
        rot_state_dict = self.rotatory_attention.state_dict()
        recon_state_dict = self.reconstruct.state_dict()
        decoder_state_dict = self.decoder.state_dict()
        out_state_dict = self.out.state_dict()

        for name, param in self.trained_model.state_dict().items():
            if name.startswith('dense'):
                dense_state_dict[name[len('dense.'):]].copy_(param)
            elif name.startswith('linear_embedding'):
                embedding_state_dict[name[len('linear_embedding.'):]].copy_(param)
            elif name.startswith('transformer'):
                transformer_state_dict[name[len('transformer.'):]].copy_(param)
            elif name.startswith('rotatory_attention'):
                rot_state_dict[name[len('rotatory_attention.'):]].copy_(param)
            elif name.startswith('reconstruct'):
                recon_state_dict[name[len('reconstruct.'):]].copy_(param)
            elif name.startswith('decoder'):
                decoder_state_dict[name[len('decoder.'):]].copy_(param)
            elif name.startswith('out'):
                out_state_dict[name[len('out.'):]].copy_(param)

        # Inference mode for every sub-module; gradients are still computed
        # during backward, which is all Grad-CAM requires.
        self.dense.eval()
        self.linear_embedding.eval()
        self.transformer.eval()
        self.rotatory_attention.eval()
        self.reconstruct.eval()
        self.decoder.eval()
        self.out.eval()

        self.gradients = []

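    # --- Grad-CAM plumbing ---
    # The hook below is attached to the decoder output in get_activations();
    # autograd calls it during backward with d(class score)/d(activations).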
    def activations_hook(self, grad):
        self.gradients.append(grad)

    def get_activations_gradient(self):
        return self.gradients

    def clear_activations_gradient(self):
        self.gradients.clear()

    def get_activations(self, x):
        x1, x2, x3, x4 = self.dense(x)
        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
        e1, e2, e3, a1_weights, a2_weights, a3_weights = self.transformer(z1, z2, z3)
        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)

        # Fuse transformer and rotatory-attention features.
        f1 = e1 + r1
        f2 = e2 + r2
        f3 = e3 + r3

        o1, o2, o3 = self.reconstruct(f1, f2, f3)
        y = self.decoder(o1, o2, o3, x4)

        # Hook the decoder output so its gradients are captured on backward.
        y.register_hook(self.activations_hook)
        return self.out(y), y

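    # Plain inference path: identical to get_activations but without the hook,
    # so RotModel can also be used as an ordinary nn.Module.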
    def forward(self, x):
        x1, x2, x3, x4 = self.dense(x)
        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
        e1, e2, e3, a1_weights, a2_weights, a3_weights = self.transformer(z1, z2, z3)
        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)

        f1 = e1 + r1
        f2 = e2 + r2
        f3 = e3 + r3

        o1, o2, o3 = self.reconstruct(f1, f2, f3)
        y = self.decoder(o1, o2, o3, x4)

        return self.out(y)


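# seg_gradcam runs Grad-CAM once per selected slice: the summed logits of the
# requested class are back-propagated to the hooked decoder output, the
# gradients are pooled into per-channel weights, and the weighted activation
# map is overlaid on the slice.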
def seg_gradcam(model, image_path, index_list: list, instance_list: list, colormap=cv2.COLORMAP_JET, img_size=(512, 512)):

    assert len(index_list) == len(instance_list), \
        "Length of list of indices is not equal to length of list of instances"

    name_dict = {
        0: "background",
        1: "left_ventricle",
        2: "right_ventricle",
        3: "left_atrium",
        4: "right_atrium",
        5: "myocardium",
        6: "descending_aorta",
        7: "pulmonary_trunk",
        8: "ascending_aorta",
        9: "vena_cava",
        10: "auricle",
        11: "coronary_artery",
    }

    num_slice = len(index_list)
    image = sitk.GetArrayFromImage(sitk.ReadImage(image_path, sitk.sitkFloat32))

    output = np.ones((num_slice, 512, 512), dtype=np.float32)
    slice_list = np.array([image[index] for index in index_list])

    fig, axes = plt.subplots(nrows=2, ncols=num_slice, figsize=(16, 16))
    for i in range(num_slice):
        output[i, :, :] = slice_list[i]

    # Batch the selected slices as (num_slice, 1, H, W) and run the model once.
    inputs = torch.from_numpy(output).unsqueeze(1).cuda()
    output, activations = model.get_activations(inputs)
    activations = activations.detach()

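    # Grad-CAM per slice: heatmap = ReLU(sum_k alpha_k * A_k), where A_k is
    # channel k of the decoder activations and alpha_k is the spatially pooled
    # gradient of the class score with respect to A_k.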
    for x in range(num_slice):
        # Sum the logits of the target class over the slice and back-propagate;
        # the hook stores the gradients of the decoder output.
        class_output = output[x, instance_list[x]]
        class_score_sum = class_output.sum()
        class_score_sum.backward(retain_graph=True)

        gradients = model.get_activations_gradient()
        print(f"Length gradients: {len(gradients)}")
        gradients = gradients[-1]
        model.clear_activations_gradient()

        # Global-average-pool the gradients into one weight per channel.
        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
        print(f"Pooled gradients: {pooled_gradients.shape}")

        # Weight each activation channel by its pooled gradient.
        instance_activation = activations[x]
        for channel in range(instance_activation.shape[0]):
            instance_activation[channel, :, :] *= pooled_gradients[channel]
        print(f"Activations: {instance_activation.shape}")

        heatmap = torch.mean(instance_activation, dim=0)
        print(f"Heatmap shape: {heatmap.shape}")

        # ReLU first, then normalize to [0, 1] before colorizing.
        heatmap = heatmap.detach().cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        heatmap /= (np.max(heatmap) + 1e-8)

        heatmap = cv2.resize(heatmap, img_size)
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, colormap)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV is BGR; matplotlib expects RGB

        # Overlay the heatmap on the original grayscale slice.
        image0 = torch.squeeze(inputs[x, :, :, :])
        image0 = image0.detach().cpu().numpy()
        image0 = np.stack((image0,) * 3, axis=-1)

        superimposed_img = (heatmap / 255.0) * 0.6 + image0

        superimposed_plot = axes[0][x].imshow(superimposed_img)
        axes[0][x].set_xlabel(
            f"Slice: {index_list[x]} || Class: {name_dict[instance_list[x]]}")
        fig.colorbar(superimposed_plot, ax=axes[0][x], fraction=0.046)

        heatmap_plot = axes[1][x].imshow(heatmap)
        axes[1][x].set_xlabel(f"Heatmap of slice: {index_list[x]}")
        fig.colorbar(heatmap_plot, ax=axes[1][x], fraction=0.046)

    title_text = ", ".join(str(index) for index in index_list)
    plt.suptitle(
        f"Model's focus on slices number: {title_text}", fontsize=16)
    plt.subplots_adjust(wspace=0.5)
    plt.show()


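# Example run from the original script: four slices of one VHSCDD_512 test
# volume, each attributed to a different anatomical class (see name_dict).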
if __name__ == "__main__":
    image_path = 'data/VHSCDD_512/test_images/0001.nii.gz'
    model = RotModel()

    # One class instance per selected slice (see name_dict above).
    index_list = [125, 156, 153, 180]
    instance_list = [3, 5, 11, 6]

    seg_gradcam(
        model=model,
        image_path=image_path,
        index_list=index_list,
        instance_list=instance_list
    )