Diff of /gcam.py [000000] .. [70e190]

--- a
+++ b/gcam.py
@@ -0,0 +1,219 @@
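+"""Grad-CAM visualisation for the RotCAtt-TransUNet++ segmentation model."""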
+import torch
+import torch.nn as nn
+from networks.RotCAtt_TransUNet_plusplus_gradcam import RotCAtt_TransUNet_plusplus_GradCam
+import SimpleITK as sitk
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
+from networks.dense_feature_extraction import Dense
+from networks.linear_embedding import LinearEmbedding
+from networks.transformer import Transformer
+from networks.rotatory_attention import RotatoryAttention
+from networks.recon import Reconstruction
+from networks.uct_decoder import UCTDecoder
+from networks.config import get_config
+
+class RotModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model_path = 'outputs/RotCAtt_TransUNet_plusplus/VHSCDD_RotCAtt_TransUNet_plusplus_bs6_ps16_epo600_hw512_ly4/model.pth'
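+        # the checkpoint is assumed to be the full pickled model; only its
+        # weights are reused, copied into the freshly built submodules below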
+        self.trained_model = torch.load(self.model_path)
+        self.config = get_config() 
+        self.dense = Dense(self.config).cuda()
+        self.linear_embedding = LinearEmbedding(self.config).cuda()
+        self.transformer = Transformer(self.config).cuda()
+        self.rotatory_attention = RotatoryAttention(self.config).cuda()
+        self.reconstruct = Reconstruction(self.config).cuda()
+        self.decoder = UCTDecoder(self.config).cuda()
+        self.out = nn.Conv2d(self.config.df[0], self.config.num_classes, kernel_size=(1,1), stride=(1,1)).cuda()
+
+        # collect the state dicts of the freshly built submodules
+        dense_state_dict = self.dense.state_dict()
+        embedding_state_dict = self.linear_embedding.state_dict()
+        transformer_state_dict = self.transformer.state_dict()
+        rot_state_dict = self.rotatory_attention.state_dict()
+        recon_state_dict = self.reconstruct.state_dict()
+        decoder_state_dict = self.decoder.state_dict()
+        out_state_dict = self.out.state_dict()  
+
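+        # state_dict() returns references to the live parameter tensors, so
+        # copy_() writes the trained weights straight into each submodule
+        # without a separate load_state_dict() call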
+        for name, param in self.trained_model.state_dict().items():
+            if name.startswith('dense'):
+                dense_state_dict[name[len("dense."):]].copy_(param)
+            elif name.startswith('linear_embedding'):
+                embedding_state_dict[name[len("linear_embedding."):]].copy_(param)
+            elif name.startswith('transformer'):
+                transformer_state_dict[name[len('transformer.'):]].copy_(param)
+            elif name.startswith('rotatory_attention'):
+                rot_state_dict[name[len('rotatory_attention.'):]].copy_(param)
+            elif name.startswith('reconstruct'):
+                recon_state_dict[name[len('reconstruct.'):]].copy_(param)
+            elif name.startswith('decoder'):
+                decoder_state_dict[name[len('decoder.'):]].copy_(param)
+            elif name.startswith('out'):
+                out_state_dict[name[len('out.'):]].copy_(param)
+
+        self.dense.eval()
+        self.linear_embedding.eval()
+        self.transformer.eval()
+        self.rotatory_attention.eval()
+        self.reconstruct.eval()
+        self.decoder.eval()
+        self.out.eval()
+        self.gradients = []
+        
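+    # --- Grad-CAM plumbing: the decoder output is the target layer; a
+    # backward hook stores its gradients for channel-importance pooling ---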
+    def activations_hook(self, grad):
+        self.gradients.append(grad)
+
+    def get_activations_gradient(self):
+        return self.gradients
+
+    def clear_activations_gradient(self):
+        self.gradients.clear() 
+        
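+    # same pipeline as forward(), but registers the gradient hook on the
+    # decoder output and returns it alongside the logits for Grad-CAM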
+    def get_activations(self, x):
+        x1, x2, x3, x4 = self.dense(x)
+        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
+        e1, e2, e3, a1_weights, a2_weights, a3_weights = self.transformer(z1, z2, z3)
+        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)
+
+        f1 = e1 + r1
+        f2 = e2 + r2
+        f3 = e3 + r3
+
+        o1, o2, o3 = self.reconstruct(f1, f2, f3)
+        y = self.decoder(o1, o2, o3, x4)
+        y.register_hook(self.activations_hook)
+        return self.out(y), y
+        
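+    # hook-free pass used for plain inference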
+    def forward(self, x):
+        x1, x2, x3, x4 = self.dense(x)
+        z1, z2, z3 = self.linear_embedding(x1, x2, x3)
+        e1, e2, e3, a1_weights, a2_weights, a3_weights = self.transformer(z1, z2, z3)
+        r1, r2, r3 = self.rotatory_attention(z1, z2, z3)
+
+        f1 = e1 + r1
+        f2 = e2 + r2
+        f3 = e3 + r3
+        
+        o1, o2, o3 = self.reconstruct(f1, f2, f3)
+        y = self.decoder(o1, o2, o3, x4)
+        
+        return self.out(y)
+    
+    
+def seg_gradcam(model, image_path, index_list: list, instance_list: list, colormap=cv2.COLORMAP_JET, img_size=(512,512)):
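+    """Overlay Grad-CAM heatmaps on the requested slices.
+
+    index_list    -- slice indices into the volume at image_path
+    instance_list -- class index (see name_dict below) to explain per slice
+    """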
+    
+    assert len(index_list) == len(instance_list), \
+        "Length of index_list must equal length of instance_list"
+    
+    name_dict = {
+        0: "background",
+        1: "left_ventricle",
+        2: "right_ventricle",
+        3: "left_atrium",
+        4: "right_atrium",
+        5: "myocardium",
+        6: "descending_aorta",
+        7: "pulmonary_trunk",
+        8: "ascending_aorta",
+        9: "vena_cava",
+        10: "auricle",
+        11: "coronary_artery",
+    }
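+    # label indices are assumed to follow the VHSCDD annotation order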
+    
+    num_slice = len(index_list)
+    image = sitk.GetArrayFromImage(sitk.ReadImage(image_path, sitk.sitkFloat32))
+    
+    # stack the requested slices into a (num_slice, H, W) float32 batch
+    slice_list = np.array([image[index] for index in index_list], dtype=np.float32)
+
+    fig, axes = plt.subplots(nrows=2, ncols=num_slice, figsize=(16, 16))
+
+    inputs = torch.from_numpy(slice_list).unsqueeze(1).cuda()
+    logits, activations = model.get_activations(inputs)
+    activations = activations.detach()
+    
+    # back-propagate each slice's class score and pool the hooked gradients
+    for x in range(num_slice):
+        class_score = logits[x, instance_list[x]].sum()
+        class_score.backward(retain_graph=True)
+
+        gradients = model.get_activations_gradient()[-1]
+        model.clear_activations_gradient()
+
+        # Grad-CAM channel weights: mean gradient over batch and spatial dims
+        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
+
+        # clone so the in-place channel weighting below cannot corrupt the
+        # autograd-tracked activation buffer between backward passes
+        instance_activation = activations[x].clone()
+        for channel in range(instance_activation.size(0)):
+            instance_activation[channel, :, :] *= pooled_gradients[channel]
+
+        heatmap = torch.mean(instance_activation, dim=0)
+
+        # rectify, then normalise to [0, 1] before colour-mapping
+        heatmap = torch.relu(heatmap)
+        heatmap /= (torch.max(heatmap) + 1e-8)
+        heatmap = heatmap.detach().cpu().numpy()
+
+        heatmap = cv2.resize(heatmap, img_size)
+        heatmap = np.uint8(255 * heatmap)
+        heatmap = cv2.applyColorMap(heatmap, colormap)
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB
+
+        # overlay the heatmap on the original input slice (not the logits),
+        # normalised to [0, 1] for blending
+        image0 = cv2.resize(slice_list[x], img_size)
+        image0 = (image0 - image0.min()) / (image0.max() - image0.min() + 1e-8)
+        image0 = np.stack((image0,) * 3, axis=-1)
+
+        superimposed_img = np.clip((heatmap / 255.0) * 0.6 + image0, 0, 1)
+
+        overlay_plot = axes[0][x].imshow(superimposed_img)
+        axes[0][x].set_xlabel(
+            f"Slice: {index_list[x]} || Class: {name_dict[instance_list[x]]}")
+        fig.colorbar(overlay_plot, ax=axes[0][x], fraction=0.046)
+
+        heatmap_plot = axes[1][x].imshow(heatmap)
+        axes[1][x].set_xlabel(f"Heatmap of slice: {index_list[x]}")
+        fig.colorbar(heatmap_plot, ax=axes[1][x], fraction=0.046)
+
+    title_text = ""
+    for i, x in enumerate(index_list):
+        title_text += str(x)
+
+        if i == len(index_list) - 1:
+            pass
+        else:
+            title_text += ", "
+
+    plt.suptitle(
+        f"Model's focus on slices number: {title_text}", fontsize=16)
+    plt.subplots_adjust(wspace=0.5)
+    plt.show()
+
+if __name__ == "__main__":
+    image_path = 'data/VHSCDD_512/test_images/0001.nii.gz'
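+    # assumes a CUDA-capable GPU and the checkpoint path hard-coded in RotModel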
+    model = RotModel()
+    
+    index_list    = [125, 156, 153, 180]
+    instance_list = [3, 5, 11, 6]
+
+    seg_gradcam(
+        model=model,
+        image_path=image_path,
+        index_list=index_list,
+        instance_list=instance_list
+    )
\ No newline at end of file