Diff of /models.py [000000] .. [d5c425]

Switch to side-by-side view

--- a
+++ b/models.py
@@ -0,0 +1,534 @@
+import torch
+from torch import nn
+import math
+from typing import Union, Sequence, Tuple
+from utils import define_act_layer
+from swintransformer import SwinTransformer, look_up_option
+from einops import rearrange
+import os 
+from medpy.io import load, header
+import numpy as np
+
+
+
+class MultiTaskModel(nn.Module):
+    def __init__(self, task, in_features, hidden_units=None, act_layer=nn.ReLU(), dropout=0.7) -> None:
+
+        super().__init__()
+        self.act = act_layer
+        incoming_features = in_features
+        hidden_layer_list = []
+        self.task = task
+        for hidden_unit in hidden_units:
+            hidden_block = nn.Sequential(
+                nn.Linear(incoming_features, hidden_unit),
+                nn.LeakyReLU(0.1, inplace=True),
+                nn.BatchNorm1d(hidden_unit),
+                nn.Dropout(dropout),
+            )
+            hidden_layer_list.append(hidden_block)
+            incoming_features = hidden_unit
+        self.hidden_layer = nn.Sequential(*hidden_layer_list)
+        out_features = 2 if self.task=="multitask" else 1
+        self.classifier = nn.Linear(hidden_units[-1], out_features)
+        # self.output_act = nn.Sigmoid()
+        # self.output_act1 = nn.LeakyReLU()
+        
+
+    def forward(self, x):
+        x = self.hidden_layer(x)
+        classifier = self.classifier(x)
+        # print(classifier)
+        if self.task =="multitask":
+            grade, hazard = classifier[0], classifier[1]
+            return grade, hazard
+        else:
+            # print(self.output_act(classifier))
+            # return self.output_act(classifier)
+            return classifier
+
+
+class SelfAttentionBi(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super(SelfAttentionBi, self).__init__()
+
+        self.WQ = nn.Linear(dim_in, dim_out)
+        self.WK = nn.Linear(dim_in, dim_out)
+        self.WV = nn.Linear(dim_in, dim_out)
+        self.root = math.sqrt(dim_out)
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(self, mod1, mod2):
+        x = torch.stack((mod1, mod2), dim=1)
+        Q = self.WQ(x)
+        K = self.WK(x)
+        V = self.WV(x)
+
+        QK = torch.bmm(Q, K.transpose(1, 2))
+        attention_matrix = self.softmax(QK/self.root)
+        out = torch.bmm(attention_matrix, V)
+        return out
+
+
+class FusionModelBi(nn.Module):
+    def __init__(self, args, dim_in, dim_out):
+        super(FusionModelBi, self).__init__()
+        self.fusion_type = args.fusion_type
+        act_layer = define_act_layer(args.act_type)
+
+        if self.fusion_type == "attention":
+            self.attention_module = SelfAttentionBi(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_out*2, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "fused_attention":
+            self.attention_module = SelfAttentionBi(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_out+1)**2, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "kronecker":
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_in+1)**2, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "concatenation":
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_in*2, args.hidden_units, act_layer, args.dropout)
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+
+    def forward(self, vec1, vec2):
+
+        if self.fusion_type == "attention":
+            x = self.attention_module(vec1, vec2)
+            x = x.view(x.shape[0], x.shape[1]*x.shape[2])
+
+        elif self.fusion_type == "kronecker":
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            x = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+
+        elif self.fusion_type == "fused_attention":
+            vec1, vec2 = self.attention_module(
+                vec1, vec2)[:, 0, :], self.attention_module(vec1, vec2)[:, 1, :]
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            x = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+            print(x.shape)
+
+        elif self.fusion_type == "concatenation":
+            x = torch.cat((vec1, vec2), dim=1)
+
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+        return self.taskmodel(x)
+    
+class SelfAttention(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super(SelfAttention, self).__init__()
+
+        self.WQ = nn.Linear(dim_in, dim_out)
+        self.WK = nn.Linear(dim_in, dim_out)
+        self.WV = nn.Linear(dim_in, dim_out)
+        self.root = math.sqrt(dim_out)
+        self.softmax = nn.Softmax(dim=1)
+
+    def forward(self, mod1, mod2, mod3):
+        x = torch.stack((mod1, mod2 ,mod3), dim=1)
+        Q = self.WQ(x) 
+        K = self.WK(x) 
+        V = self.WV(x) 
+
+        QK = torch.bmm(Q, K.transpose(1, 2)) 
+        attention_matrix = self.softmax(QK/self.root)
+        out = torch.bmm(attention_matrix, V)
+        return out
+
+    
+class FusionModel(nn.Module):
+    def __init__(self, args, dim_in, dim_out):
+        super(FusionModel, self).__init__()
+        self.fusion_type = args.fusion_type
+        act_layer = define_act_layer(args.act_type)
+    
+        if self.fusion_type == "attention":
+            self.attention_module = SelfAttention(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_out*3, args.hidden_units, act_layer, args.dropout)
+            
+        elif self.fusion_type == "fused_attention":
+            self.attention_module = SelfAttention(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_out+1)**3, args.hidden_units, act_layer, args.dropout)
+            
+        elif self.fusion_type == "kronecker":
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_in+1)**3, args.hidden_units, act_layer, args.dropout)
+            
+        elif self.fusion_type == "concatenation":
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_in*3, args.hidden_units, act_layer, args.dropout)
+            
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+        
+    def forward(self, vec1, vec2, vec3):
+        
+        if self.fusion_type == "attention":
+            x = self.attention_module(vec1, vec2, vec3)
+            x = x.view(x.shape[0], x.shape[1]*x.shape[2])
+            
+        elif self.fusion_type == "kronecker":
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            vec3 = torch.cat(
+                (vec3, torch.ones((vec3.shape[0], 1)).to(vec3.device)), 1)
+            x12 = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+            start_dim=1)
+            x = torch.bmm(x12.unsqueeze(2), vec3.unsqueeze(1)).flatten(
+                start_dim=1)
+            
+        elif self.fusion_type == "fused_attention":
+            vec1, vec2, vec3 = self.attention_module(
+                vec1, vec2, vec3)[:, 0, :], self.attention_module(vec1, vec2, vec3)[:, 1, :] , self.attention_module(vec1, vec2, vec3)[:, 2, :]
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            vec3 = torch.cat(
+                (vec3, torch.ones((vec3.shape[0], 1)).to(vec3.device)), 1)
+            x12 = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+            x = torch.bmm(x12.unsqueeze(2), vec3.unsqueeze(1)).flatten(
+                start_dim=1)
+            
+        elif self.fusion_type == "concatenation":
+            x = torch.cat((vec1, vec2, vec3), dim=1)
+
+        else: 
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+        return self.taskmodel(x)
+
+    
+
+    
+class MultiHeadSelfAttention(nn.Module):
+    def __init__(self, emb_size, num_heads):
+        super().__init__()
+        self.emb_size = emb_size
+        self.num_heads = num_heads
+        self.qkv = nn.Linear(emb_size, emb_size * 3)
+        self.attention = nn.MultiheadAttention(emb_size, num_heads)
+
+    def forward(self, x):
+        qkv = self.qkv(x).chunk(3, dim=-1)
+        q, k, v = [rearrange(tensor, "b n (h d) -> b n h d", h=self.num_heads) for tensor in qkv]
+        q = rearrange(q, "b n h d -> (b h) n d")
+        k = rearrange(k, "b n h d -> (b h) n d")
+        v = rearrange(v, "b n h d -> (b h) n d")
+        attn_output, _ = self.attention(q, k, v)
+        attn_output = rearrange(attn_output, "(b h) n d -> b n (h d)", h=self.num_heads)
+        return attn_output
+
+class PatchEmbedding(nn.Module):
+    def __init__(self, in_channels, patch_size, emb_size):
+        super().__init__()
+        self.patch_size = patch_size
+        self.emb_size = emb_size
+        self.projection = nn.Conv3d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
+        self.flatten = nn.Flatten(2)
+        self.linear = nn.Linear(emb_size, emb_size)
+
+    def forward(self, x):
+        x = self.projection(x)  # (B, emb_size, D/P, H/P, W/P)
+        x = self.flatten(x)
+        return self.linear(x.transpose(-1, -2))  # (B, num_patches, emb_size)
+
+class TransformerBlock(nn.Module):
+    def __init__(self, emb_size, num_heads, mlp_ratio=4.0, dropout_rate=0.1):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(emb_size)
+        self.attn = nn.MultiheadAttention(emb_size, num_heads)
+        self.norm2 = nn.LayerNorm(emb_size)
+        self.ff = nn.Sequential(
+            nn.Linear(emb_size, int(emb_size * mlp_ratio)),
+            nn.GELU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(int(emb_size * mlp_ratio), emb_size),
+            nn.Dropout(dropout_rate),
+        )
+
+    def forward(self, x):
+        # x is expected to be of shape (num_patches, batch_size, emb_size)
+        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
+        x = x + self.ff(self.norm2(x))
+        return x
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, in_channels, patch_size, emb_size, depth, num_heads, mlp_ratio=4.0, dropout_rate=0.1):
+        super().__init__()
+        self.embedding = PatchEmbedding(in_channels, patch_size, emb_size)
+        self.blocks = nn.ModuleList([
+            TransformerBlock(emb_size, num_heads, mlp_ratio, dropout_rate) for _ in range(depth)
+        ])
+        self.norm = nn.LayerNorm(emb_size)
+        self.pool = nn.AdaptiveAvgPool1d(1)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm(x)
+        x = self.pool(x.transpose(-1, -2)).squeeze(-1)
+        return x
+    
+# class SwinTransformerRadiologyModel(nn.Module):
+
+#     def __init__(
+#         self,
+#         patch_size: Union[Sequence[int], int],
+#         window_size: Union[Sequence[int], int],
+#         in_channels: int,
+#         out_channels: int,
+#         depths: Sequence[int] = (2, 2, 2, 2),
+#         num_heads: Sequence[int] = (3, 6, 12, 24),
+#         #try different feature_size!!
+#         feature_size: int = 24,
+#         norm_name: Union[Tuple, str] = "instance",
+#         drop_rate: float = 0.7,
+#         attn_drop_rate: float = 0.,
+#         dropout_path_rate: float = 0.0,
+#         normalize: bool = True,
+#         use_checkpoint: bool = False,
+#         spatial_dims: int = 3,
+#         downsample="merging",
+#     ) -> None:
+#         """
+#         Input requirement : [BxCxDxHxW]
+#         Args:
+#             in_channels: dimension of input channels.
+#             out_channels: dimension of output channels.
+#             feature_size: dimension of network feature size.
+#             depths: number of layers in each stage.
+#             num_heads: number of attention heads.
+#             norm_name: feature normalization type and arguments.
+#             drop_rate: dropout rate.
+#             attn_drop_rate: attention dropout rate.
+#             dropout_path_rate: drop path rate.
+#             normalize: normalize output intermediate features in each stage.
+#             use_checkpoint: use gradient checkpointing for reduced memory usage.
+#             spatial_dims: number of spatial dims.
+#             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
+#                 user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
+#                 The default is currently `"merging"` (the original version defined in v0.9.0).
+
+
+#         """
+
+#         super().__init__()
+
+#         if not (spatial_dims == 2 or spatial_dims == 3):
+#             raise ValueError("spatial dimension should be 2 or 3.")
+
+#         self.normalize = normalize
+
+#         self.swinViT = SwinTransformer(
+#             in_chans=in_channels,
+#             embed_dim=feature_size,
+#             window_size=window_size,
+#             patch_size=patch_size,
+#             depths=depths,
+#             num_heads=num_heads,
+#             mlp_ratio=4.0,
+#             qkv_bias=True,
+#             drop_rate=drop_rate,
+#             attn_drop_rate=attn_drop_rate,
+#             drop_path_rate=dropout_path_rate,
+#             norm_layer=nn.LayerNorm,
+#             use_checkpoint=use_checkpoint,
+#             spatial_dims=spatial_dims,
+#             downsample=look_up_option(downsample) if isinstance(
+#                 downsample, str) else downsample,
+#         )
+#         self.norm = nn.LayerNorm(feature_size*16)
+#         self.avgpool = nn.AdaptiveAvgPool3d([1, 1, 1])
+#         self.dim_reduction = nn.Conv3d(feature_size*16, out_channels, 1)
+        
+#     def forward(self, x_in):
+#         hidden_states_out = self.swinViT(x_in, self.normalize)
+#         hidden_output = rearrange(
+#             hidden_states_out[4], "b c d h w -> b d h w c")
+#         nomalized_hidden_states_out = self.norm(hidden_output)
+#         nomalized_hidden_states_out = rearrange(
+#             nomalized_hidden_states_out, "b d h w c -> b c d h w")
+#         output = self.avgpool(nomalized_hidden_states_out)
+#         output = torch.flatten(self.dim_reduction(output), 1)
+#         # print(output.shape)
+
+#         return output
+
+
+
+# class CNNRadiologyModel(nn.Module):
+#     def __init__(
+#         self,
+#         in_channels: int,
+#         out_channels: int,
+#         feature_size: int = 24,
+#         spatial_dims: int = 3,
+#         dropout_rate: float = 0.7,
+#     ) -> None:
+#         """
+#         A simple 3D CNN-based feature extractor.
+#         Input requirement : [BxCxDxHxW]
+#         Args:
+#             in_channels: dimension of input channels.
+#             out_channels: dimension of output channels.
+#             feature_size: dimension of network feature size.
+#             spatial_dims: number of spatial dims (2 or 3).
+#             dropout_rate: dropout rate.
+#         """
+#         super().__init__()
+
+#         if spatial_dims != 3:
+#             raise ValueError("This implementation is designed for 3D inputs.")
+
+#         self.conv1 = nn.Conv3d(in_channels, feature_size, kernel_size=3, padding=1)
+#         self.conv2 = nn.Conv3d(feature_size, feature_size * 2, kernel_size=3, padding=1)
+#         self.conv3 = nn.Conv3d(feature_size * 2, feature_size * 4, kernel_size=3, padding=1)
+#         self.conv4 = nn.Conv3d(feature_size * 4, feature_size * 8, kernel_size=3, padding=1)
+
+#         self.bn1 = nn.BatchNorm3d(feature_size)
+#         self.bn2 = nn.BatchNorm3d(feature_size * 2)
+#         self.bn3 = nn.BatchNorm3d(feature_size * 4)
+#         self.bn4 = nn.BatchNorm3d(feature_size * 8)
+
+#         self.relu = nn.ReLU(inplace=True)
+#         self.dropout = nn.Dropout(dropout_rate)
+#         self.pool = nn.AdaptiveAvgPool3d(1)
+#         self.fc = nn.Conv3d(feature_size * 8, out_channels, kernel_size=1)
+
+#     def forward(self, x):
+#         x = self.relu(self.bn1(self.conv1(x)))
+#         x = self.relu(self.bn2(self.conv2(x)))
+#         x = self.relu(self.bn3(self.conv3(x)))
+#         x = self.relu(self.bn4(self.conv4(x)))
+#         x = self.dropout(x)
+#         x = self.pool(x)
+#         x = torch.flatten(self.fc(x), 1)
+#         #print(x.shape)
+#         return x
+
+class VisionTransformerRadiologyModel(nn.Module):
+    def __init__(
+        self,
+        patch_size: Union[Sequence[int], int],
+        in_channels: int,
+        out_channels: int,
+        emb_size: int,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        dropout_rate: float = 0.1,
+    ) -> None:
+        super().__init__()
+
+        self.vit = VisionTransformer(
+            in_channels=in_channels,
+            patch_size=patch_size,
+            emb_size=emb_size,
+            depth=depth,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            dropout_rate=dropout_rate
+        )
+        self.fc = nn.Linear(emb_size, out_channels)
+
+    def forward(self, x):
+        x = self.vit(x)
+        x = self.fc(x)
+        return x
+
+class Model(nn.Module):
+    def __init__(self, args):
+        super(Model, self).__init__()
+        # self.extractor_ct_tumor = SwinTransformerRadiologyModel(
+        #     patch_size=(1, 2, 2),
+        #     window_size=[[4, 4, 4], [4, 4, 4], [8, 8, 8], [4, 4, 4]],
+        #     in_channels=4,
+        #     out_channels=args.feature_size,
+        #     depths=(2, 2, 2, 2),
+        #     num_heads=(3, 6, 12, 24),
+        #     feature_size=int(args.feature_size/2),
+        #     norm_name="instance",
+        #     drop_rate=0.7,
+        #     attn_drop_rate=0.,
+        #     dropout_path_rate=0.2,
+        #     normalize=True,
+        #     use_checkpoint=False,
+        #     spatial_dims=3
+        # )
+        # self.extractor_ct_lymph = SwinTransformerRadiologyModel(
+        #     patch_size=(1, 2, 2),
+        #     window_size=[[4, 4, 4], [4, 4, 4], [8, 8, 8], [4, 4, 4]],
+        #     in_channels=4,
+        #     out_channels=args.feature_size,
+        #     depths=(2, 2, 2, 2),
+        #     num_heads=(3, 6, 12, 24),
+        #     feature_size=int(args.feature_size/2),
+        #     norm_name="instance",
+        #     drop_rate=0.7,
+        #     attn_drop_rate=0.,
+        #     dropout_path_rate=0.2,
+        #     normalize=True,
+        #     use_checkpoint=False,
+        #     spatial_dims=3
+        # )
+        # self.extractor_ct_tumor = CNNRadiologyModel(
+        #     in_channels=4, 
+        #     out_channels=args.feature_size
+        # )
+        # self.extractor_ct_lymph = CNNRadiologyModel(
+        #     in_channels=4, 
+        #     out_channels=args.feature_size
+        # )
+
+        self.extractor_ct_tumor = VisionTransformerRadiologyModel(
+            patch_size=(2, 2, 2),
+            in_channels=4,
+            out_channels=args.feature_size,
+            emb_size=int(args.feature_size/2),
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.0,
+            dropout_rate=0.1
+        )
+        self.extractor_ct_lymph = VisionTransformerRadiologyModel(
+            patch_size=(2, 2, 2),
+            in_channels=4,
+            out_channels=args.feature_size,
+            emb_size=int(args.feature_size/2),
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.0,
+            dropout_rate=0.1
+        )
+
+        self.fusion = FusionModelBi(args, args.feature_size, args.dim_out)
+        
+    def forward(self, ct_tumor, ct_lymph):
+        features_tumor = self.extractor_ct_tumor(ct_tumor)
+        features_lymph = self.extractor_ct_lymph(ct_lymph)
+        
+        output = self.fusion(features_tumor, features_lymph)
+        return output
+