--- /dev/null
+++ b/gcam.py
@@ -0,0 +1,216 @@
+import torch
+import torch.nn as nn
+from networks.RotCAtt_TransUNet_plusplus_gradcam import RotCAtt_TransUNet_plusplus_GradCam
+import SimpleITK as sitk
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
+from networks.dense_feature_extraction import Dense
+from networks.linear_embedding import LinearEmbedding
+from networks.transformer import Transformer
+from networks.rotatory_attention import RotatoryAttention
+from networks.recon import Reconstruction
+from networks.uct_decoder import UCTDecoder
+from networks.config import get_config
+
+
+class RotModel(nn.Module):
+    """RotCAtt_TransUNet_plusplus rebuilt from its sub-modules so that a
+    gradient hook can be attached to the decoder output for Grad-CAM."""
+
+    def __init__(self):
+        super().__init__()
+        self.model_path = 'outputs/RotCAtt_TransUNet_plusplus/VHSCDD_RotCAtt_TransUNet_plusplus_bs6_ps16_epo600_hw512_ly4/model.pth'
+        self.trained_model = torch.load(self.model_path)
+        self.config = get_config()
+        self.dense = Dense(self.config).cuda()
+        self.linear_embedding = LinearEmbedding(self.config).cuda()
+        self.transformer = Transformer(self.config).cuda()
+        self.rotatory_attention = RotatoryAttention(self.config).cuda()
+        self.reconstruct = Reconstruction(self.config).cuda()
+        self.decoder = UCTDecoder(self.config).cuda()
+        self.out = nn.Conv2d(self.config.df[0], self.config.num_classes,
+                             kernel_size=(1, 1), stride=(1, 1)).cuda()
+
+        # state_dict() returns tensors that share storage with the module
+        # parameters, so copy_() below loads the trained weights in place
+        dense_state_dict = self.dense.state_dict()
+        embedding_state_dict = self.linear_embedding.state_dict()
+        transformer_state_dict = self.transformer.state_dict()
+        rot_state_dict = self.rotatory_attention.state_dict()
+        recon_state_dict = self.reconstruct.state_dict()
+        decoder_state_dict = self.decoder.state_dict()
+        out_state_dict = self.out.state_dict()
+
+        # redistribute the trained weights to the matching sub-modules
+        for name, param in self.trained_model.state_dict().items():
+            if name.startswith('dense'):
+                dense_state_dict[name[len('dense.'):]].copy_(param)
+            elif name.startswith('linear_embedding'):
+                embedding_state_dict[name[len('linear_embedding.'):]].copy_(param)
+            elif name.startswith('transformer'):
+                transformer_state_dict[name[len('transformer.'):]].copy_(param)
+            elif name.startswith('rotatory_attention'):
+                rot_state_dict[name[len('rotatory_attention.'):]].copy_(param)
+            elif name.startswith('reconstruct'):
+                recon_state_dict[name[len('reconstruct.'):]].copy_(param)
+            elif name.startswith('decoder'):
+                decoder_state_dict[name[len('decoder.'):]].copy_(param)
+            elif name.startswith('out'):
+                out_state_dict[name[len('out.'):]].copy_(param)
+
+        self.dense.eval()
+        self.linear_embedding.eval()
+        self.transformer.eval()
+        self.rotatory_attention.eval()
+        self.reconstruct.eval()
+        self.decoder.eval()
+        self.out.eval()
+        self.gradients = []
+
+    def activations_hook(self, grad):
+        self.gradients.append(grad)
+
+    def get_activations_gradient(self):
+        return self.gradients
+
+    def clear_activations_gradient(self):
+        self.gradients.clear()
+
+    def get_activations(self, x):
+        # same forward pass as forward(), but also returns the decoder output
+        # and registers a hook that records its gradient during backward()
+        x1, x2, x3, x4 = self.dense(x)
+        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
+        e1, e2, e3, a1_weights, a2_weights, a3_weights = self.transformer(z1, z2, z3)
+        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)
+
+        f1 = e1 + r1
+        f2 = e2 + r2
+        f3 = e3 + r3
+
+        o1, o2, o3 = self.reconstruct(f1, f2, f3)
+        y = self.decoder(o1, o2, o3, x4)
+        y.register_hook(self.activations_hook)
+        return self.out(y), y
+
+    def forward(self, x):
+        x1, x2, x3, x4 = self.dense(x)
+        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
+        e1, e2, e3, a1_weights, a2_weights, a3_weights = self.transformer(z1, z2, z3)
+        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)
+
+        f1 = e1 + r1
+        f2 = e2 + r2
+        f3 = e3 + r3
+
+        o1, o2, o3 = self.reconstruct(f1, f2, f3)
+        y = self.decoder(o1, o2, o3, x4)
+
+        return self.out(y)
+
+
+def seg_gradcam(model, image_path, index_list: list, instance_list: list,
+                colormap=cv2.COLORMAP_JET, img_size=(512, 512)):
+    """Plot Grad-CAM heatmaps of the decoder output for the given slices and target classes."""
+
+    assert len(index_list) == len(instance_list), \
+        "Length of index_list must equal length of instance_list"
+
+    name_dict = {
+        0: "background",
+        1: "left_ventricle",
+        2: "right_ventricle",
+        3: "left_atrium",
+        4: "right_atrium",
+        5: "myocardium",
+        6: "descending_aorta",
+        7: "pulmonary_trunk",
+        8: "ascending_aorta",
+        9: "vena_cava",
+        10: "auricle",
+        11: "coronary_artery",
+    }
+
+    num_slice = len(index_list)
+    image = sitk.GetArrayFromImage(sitk.ReadImage(image_path, sitk.sitkFloat32))
+
+    # stack the requested slices into a (num_slice, 1, H, W) batch
+    slices = np.array([image[index] for index in index_list], dtype=np.float32)
+    inputs = torch.from_numpy(slices).unsqueeze(1).cuda()
+
+    logits, activations = model.get_activations(inputs)
+    activations = activations.detach()
+
+    fig, axes = plt.subplots(nrows=2, ncols=num_slice, figsize=(16, 16))
+
+    for x in range(num_slice):
+        # back-propagate the summed score of the target class for this slice
+        class_output = logits[x, instance_list[x]]
+        class_score_sum = class_output.sum()
+        class_score_sum.backward(retain_graph=True)
+
+        gradients = model.get_activations_gradient()
+        print(f"Length gradients: {len(gradients)}")
+        gradients = gradients[-1]
+        model.clear_activations_gradient()
+
+        # global-average-pool the gradients to one weight per channel
+        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
+        print(f"Pooled Gradient: {pooled_gradients.shape}")
+
+        # clone so the channel-wise weighting does not overwrite the shared
+        # activation tensor still needed by the remaining slices
+        instance_activation = activations[x].clone()
+        for channel in range(instance_activation.shape[0]):
+            instance_activation[channel, :, :] *= pooled_gradients[channel]
+        print(f"Activations: {instance_activation.shape}")
+
+        # average over channels, keep positive influence only, normalise
+        heatmap = torch.mean(instance_activation, dim=0)
+        print(f"Heatmap shape: {heatmap.shape}")
+        heatmap = torch.relu(heatmap)
+        heatmap /= (torch.max(heatmap) + 1e-8)
+        heatmap = heatmap.detach().cpu().numpy()
+
+        heatmap = cv2.resize(heatmap, img_size)
+        heatmap = np.uint8(255 * heatmap)
+        heatmap = cv2.applyColorMap(heatmap, colormap)
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
+
+        # original slice, rescaled to [0, 1] and expanded to 3 channels for display
+        image0 = slices[x]
+        image0 = (image0 - image0.min()) / (image0.max() - image0.min() + 1e-8)
+        image0 = np.stack((image0,) * 3, axis=-1)
+
+        superimposed_img = np.clip((heatmap / 255.0) * 0.6 + image0, 0.0, 1.0)
+
+        overlay_plot = axes[0][x].imshow(superimposed_img)
+        axes[0][x].set_xlabel(
+            f"Slice: {index_list[x]} || Class: {name_dict[instance_list[x]]}")
+        fig.colorbar(overlay_plot, ax=axes[0][x], fraction=0.046)
+
+        heatmap_plot = axes[1][x].imshow(heatmap)
+        axes[1][x].set_xlabel(f"Heatmap of slice: {index_list[x]}")
+        fig.colorbar(heatmap_plot, ax=axes[1][x], fraction=0.046)
+
+    title_text = ", ".join(str(index) for index in index_list)
+    plt.suptitle(f"Model's focus on slices number: {title_text}", fontsize=16)
+    plt.subplots_adjust(wspace=0.5)
+    plt.show()
+
+
+if __name__ == "__main__":
+    image_path = 'data/VHSCDD_512/test_images/0001.nii.gz'
+    model = RotModel()
+
+    index_list = [125, 156, 153, 180]
+    instance_list = [3, 5, 11, 6]
+
+    seg_gradcam(
+        model=model,
+        image_path=image_path,
+        index_list=index_list,
+        instance_list=instance_list
+    )
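
For reference, the per-channel weighting, channel averaging and normalisation done inside the slice loop above can be expressed without the explicit channel loop. The sketch below is illustrative only and is not part of gcam.py; the helper name gradcam_heatmap is hypothetical, and it assumes a detached activation tensor of shape (C, H, W) and channel-pooled gradients of shape (C,).

import torch

def gradcam_heatmap(activation: torch.Tensor, pooled_gradients: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # activation: (C, H, W) decoder feature map for one slice (detached clone)
    # pooled_gradients: (C,) gradients averaged over batch and spatial dimensions
    cam = (pooled_gradients[:, None, None] * activation).mean(dim=0)  # weight channels, then average
    cam = torch.relu(cam)                                             # keep positive influence only
    return cam / (cam.max() + eps)                                    # normalise to [0, 1]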