--- a
+++ b/models.py
@@ -0,0 +1,534 @@
+import torch
+from torch import nn
+import math
+from typing import Union, Sequence, Tuple
+from utils import define_act_layer
+from swintransformer import SwinTransformer, look_up_option
+from einops import rearrange
+import os
+from medpy.io import load, header
+import numpy as np
+
+
+class MultiTaskModel(nn.Module):
+    """MLP head that predicts grade and/or hazard from a fused feature vector."""
+
+    def __init__(self, task, in_features, hidden_units=None, act_layer=nn.ReLU(), dropout=0.7) -> None:
+        super().__init__()
+        self.act = act_layer  # stored for reference; the hidden blocks below use LeakyReLU directly
+        incoming_features = in_features
+        hidden_layer_list = []
+        self.task = task
+        for hidden_unit in hidden_units:
+            hidden_block = nn.Sequential(
+                nn.Linear(incoming_features, hidden_unit),
+                nn.LeakyReLU(0.1, inplace=True),
+                nn.BatchNorm1d(hidden_unit),
+                nn.Dropout(dropout),
+            )
+            hidden_layer_list.append(hidden_block)
+            incoming_features = hidden_unit
+        self.hidden_layer = nn.Sequential(*hidden_layer_list)
+        out_features = 2 if self.task == "multitask" else 1
+        self.classifier = nn.Linear(hidden_units[-1], out_features)
+        # self.output_act = nn.Sigmoid()
+        # self.output_act1 = nn.LeakyReLU()
+
+    def forward(self, x):
+        x = self.hidden_layer(x)
+        classifier = self.classifier(x)
+        # print(classifier)
+        if self.task == "multitask":
+            # split along the feature dimension so every sample keeps both outputs
+            grade, hazard = classifier[:, 0], classifier[:, 1]
+            return grade, hazard
+        else:
+            # print(self.output_act(classifier))
+            # return self.output_act(classifier)
+            return classifier
+
+
+class SelfAttentionBi(nn.Module):
+    """Scaled dot-product self-attention over the two modality tokens."""
+
+    def __init__(self, dim_in, dim_out):
+        super(SelfAttentionBi, self).__init__()
+        self.WQ = nn.Linear(dim_in, dim_out)
+        self.WK = nn.Linear(dim_in, dim_out)
+        self.WV = nn.Linear(dim_in, dim_out)
+        self.root = math.sqrt(dim_out)
+        self.softmax = nn.Softmax(dim=-1)  # softmax over the key dimension
+
+    def forward(self, mod1, mod2):
+        x = torch.stack((mod1, mod2), dim=1)
+        Q = self.WQ(x)
+        K = self.WK(x)
+        V = self.WV(x)
+        QK = torch.bmm(Q, K.transpose(1, 2))
+        attention_matrix = self.softmax(QK / self.root)
+        out = torch.bmm(attention_matrix, V)
+        return out
+
+
+class FusionModelBi(nn.Module):
+    """Fuses two modality feature vectors and feeds the result to MultiTaskModel."""
+
+    def __init__(self, args, dim_in, dim_out):
+        super(FusionModelBi, self).__init__()
+        self.fusion_type = args.fusion_type
+        act_layer = define_act_layer(args.act_type)
+
+        if self.fusion_type == "attention":
+            self.attention_module = SelfAttentionBi(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_out * 2, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "fused_attention":
+            self.attention_module = SelfAttentionBi(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_out + 1) ** 2, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "kronecker":
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_in + 1) ** 2, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "concatenation":
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_in * 2, args.hidden_units, act_layer, args.dropout)
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+
+    def forward(self, vec1, vec2):
+        if self.fusion_type == "attention":
+            x = self.attention_module(vec1, vec2)
+            x = x.view(x.shape[0], x.shape[1] * x.shape[2])
+
+        elif self.fusion_type == "kronecker":
+            # append a constant 1 so the outer product keeps the unimodal terms
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            x = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+
+        elif self.fusion_type == "fused_attention":
+            # run the attention module once and split its two output tokens
+            attended = self.attention_module(vec1, vec2)
+            vec1, vec2 = attended[:, 0, :], attended[:, 1, :]
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            x = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+
+        elif self.fusion_type == "concatenation":
+            x = torch.cat((vec1, vec2), dim=1)
+
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+        return self.taskmodel(x)
+
+
+class SelfAttention(nn.Module):
+    """Scaled dot-product self-attention over the three modality tokens."""
+
+    def __init__(self, dim_in, dim_out):
+        super(SelfAttention, self).__init__()
+        self.WQ = nn.Linear(dim_in, dim_out)
+        self.WK = nn.Linear(dim_in, dim_out)
+        self.WV = nn.Linear(dim_in, dim_out)
+        self.root = math.sqrt(dim_out)
+        self.softmax = nn.Softmax(dim=-1)  # softmax over the key dimension
+
+    def forward(self, mod1, mod2, mod3):
+        x = torch.stack((mod1, mod2, mod3), dim=1)
+        Q = self.WQ(x)
+        K = self.WK(x)
+        V = self.WV(x)
+        QK = torch.bmm(Q, K.transpose(1, 2))
+        attention_matrix = self.softmax(QK / self.root)
+        out = torch.bmm(attention_matrix, V)
+        return out
+
+
+class FusionModel(nn.Module):
+    """Three-modality counterpart of FusionModelBi."""
+
+    def __init__(self, args, dim_in, dim_out):
+        super(FusionModel, self).__init__()
+        self.fusion_type = args.fusion_type
+        act_layer = define_act_layer(args.act_type)
+
+        if self.fusion_type == "attention":
+            self.attention_module = SelfAttention(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_out * 3, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "fused_attention":
+            self.attention_module = SelfAttention(dim_in, dim_out)
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_out + 1) ** 3, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "kronecker":
+            self.taskmodel = MultiTaskModel(
+                args.task, (dim_in + 1) ** 3, args.hidden_units, act_layer, args.dropout)
+        elif self.fusion_type == "concatenation":
+            self.taskmodel = MultiTaskModel(
+                args.task, dim_in * 3, args.hidden_units, act_layer, args.dropout)
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+
+    def forward(self, vec1, vec2, vec3):
+        if self.fusion_type == "attention":
+            x = self.attention_module(vec1, vec2, vec3)
+            x = x.view(x.shape[0], x.shape[1] * x.shape[2])
+
+        elif self.fusion_type == "kronecker":
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            vec3 = torch.cat(
+                (vec3, torch.ones((vec3.shape[0], 1)).to(vec3.device)), 1)
+            x12 = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+            x = torch.bmm(x12.unsqueeze(2), vec3.unsqueeze(1)).flatten(
+                start_dim=1)
+
+        elif self.fusion_type == "fused_attention":
+            # run the attention module once and split its three output tokens
+            attended = self.attention_module(vec1, vec2, vec3)
+            vec1, vec2, vec3 = attended[:, 0, :], attended[:, 1, :], attended[:, 2, :]
+            vec1 = torch.cat(
+                (vec1, torch.ones((vec1.shape[0], 1)).to(vec1.device)), 1)
+            vec2 = torch.cat(
+                (vec2, torch.ones((vec2.shape[0], 1)).to(vec2.device)), 1)
+            vec3 = torch.cat(
+                (vec3, torch.ones((vec3.shape[0], 1)).to(vec3.device)), 1)
+            x12 = torch.bmm(vec1.unsqueeze(2), vec2.unsqueeze(1)).flatten(
+                start_dim=1)
+            x = torch.bmm(x12.unsqueeze(2), vec3.unsqueeze(1)).flatten(
+                start_dim=1)
+
+        elif self.fusion_type == "concatenation":
+            x = torch.cat((vec1, vec2, vec3), dim=1)
+
+        else:
+            raise NotImplementedError(
+                f'Fusion method {self.fusion_type} is not implemented')
+        return self.taskmodel(x)
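+
+
+# Dimensionality note for the fusion variants above (d = dim_in, k = dim_out):
+# "concatenation" passes 2*d (bi) or 3*d (tri) features to MultiTaskModel and
+# "attention" passes 2*k or 3*k, while "kronecker" and "fused_attention" append
+# a constant 1 to each vector and take outer products, giving (d+1)**2 / (k+1)**2
+# features for two modalities and (d+1)**3 / (k+1)**3 for three.  For example,
+# torch.bmm(v1.unsqueeze(2), v2.unsqueeze(1)).flatten(start_dim=1) on two (B, 49)
+# vectors (d = 48) already yields (B, 2401) fused features, so args.hidden_units
+# should be sized with that growth in mind.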
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-head self-attention over (batch, tokens, emb_size) sequences.
+
+    nn.MultiheadAttention projects Q/K/V and splits heads internally, so it is
+    used directly with batch_first=True rather than reshaping per head by hand.
+    """
+
+    def __init__(self, emb_size, num_heads):
+        super().__init__()
+        self.emb_size = emb_size
+        self.num_heads = num_heads
+        self.attention = nn.MultiheadAttention(emb_size, num_heads, batch_first=True)
+
+    def forward(self, x):
+        attn_output, _ = self.attention(x, x, x)
+        return attn_output
+
+
+class PatchEmbedding(nn.Module):
+    def __init__(self, in_channels, patch_size, emb_size):
+        super().__init__()
+        self.patch_size = patch_size
+        self.emb_size = emb_size
+        self.projection = nn.Conv3d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
+        self.flatten = nn.Flatten(2)
+        self.linear = nn.Linear(emb_size, emb_size)
+
+    def forward(self, x):
+        x = self.projection(x)  # (B, emb_size, D/P, H/P, W/P)
+        x = self.flatten(x)     # (B, emb_size, num_patches)
+        return self.linear(x.transpose(-1, -2))  # (B, num_patches, emb_size)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, emb_size, num_heads, mlp_ratio=4.0, dropout_rate=0.1):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(emb_size)
+        # batch_first=True so the block accepts the (B, num_patches, emb_size)
+        # tensors produced by PatchEmbedding without transposing
+        self.attn = nn.MultiheadAttention(emb_size, num_heads, batch_first=True)
+        self.norm2 = nn.LayerNorm(emb_size)
+        self.ff = nn.Sequential(
+            nn.Linear(emb_size, int(emb_size * mlp_ratio)),
+            nn.GELU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(int(emb_size * mlp_ratio), emb_size),
+            nn.Dropout(dropout_rate),
+        )
+
+    def forward(self, x):
+        # x is expected to be of shape (batch_size, num_patches, emb_size)
+        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
+        x = x + self.ff(self.norm2(x))
+        return x
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, in_channels, patch_size, emb_size, depth, num_heads, mlp_ratio=4.0, dropout_rate=0.1):
+        super().__init__()
+        self.embedding = PatchEmbedding(in_channels, patch_size, emb_size)
+        self.blocks = nn.ModuleList([
+            TransformerBlock(emb_size, num_heads, mlp_ratio, dropout_rate) for _ in range(depth)
+        ])
+        self.norm = nn.LayerNorm(emb_size)
+        self.pool = nn.AdaptiveAvgPool1d(1)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm(x)
+        x = self.pool(x.transpose(-1, -2)).squeeze(-1)  # mean over patch tokens -> (B, emb_size)
+        return x
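+
+
+# Shape walk-through for the ViT pipeline above (illustrative numbers only,
+# assuming a (B, 4, 8, 16, 16) CT crop, patch_size=(2, 2, 2) and emb_size=24):
+# PatchEmbedding produces 4*8*8 = 256 patch tokens of width 24, the
+# TransformerBlocks keep that (B, 256, 24) shape, and the final average pool
+# collapses the tokens to a single (B, 24) feature vector per volume.
+# No class token or positional embedding is added to the patch tokens.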
+
+
+# class SwinTransformerRadiologyModel(nn.Module):
+
+#     def __init__(
+#         self,
+#         patch_size: Union[Sequence[int], int],
+#         window_size: Union[Sequence[int], int],
+#         in_channels: int,
+#         out_channels: int,
+#         depths: Sequence[int] = (2, 2, 2, 2),
+#         num_heads: Sequence[int] = (3, 6, 12, 24),
+#         #try different feature_size!!
+#         feature_size: int = 24,
+#         norm_name: Union[Tuple, str] = "instance",
+#         drop_rate: float = 0.7,
+#         attn_drop_rate: float = 0.,
+#         dropout_path_rate: float = 0.0,
+#         normalize: bool = True,
+#         use_checkpoint: bool = False,
+#         spatial_dims: int = 3,
+#         downsample="merging",
+#     ) -> None:
+#         """
+#         Input requirement : [BxCxDxHxW]
+#         Args:
+#             in_channels: dimension of input channels.
+#             out_channels: dimension of output channels.
+#             feature_size: dimension of network feature size.
+#             depths: number of layers in each stage.
+#             num_heads: number of attention heads.
+#             norm_name: feature normalization type and arguments.
+#             drop_rate: dropout rate.
+#             attn_drop_rate: attention dropout rate.
+#             dropout_path_rate: drop path rate.
+#             normalize: normalize output intermediate features in each stage.
+#             use_checkpoint: use gradient checkpointing for reduced memory usage.
+#             spatial_dims: number of spatial dims.
+#             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
+#                 user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
+#                 The default is currently `"merging"` (the original version defined in v0.9.0).
+#         """
+#         super().__init__()
+
+#         if not (spatial_dims == 2 or spatial_dims == 3):
+#             raise ValueError("spatial dimension should be 2 or 3.")
+
+#         self.normalize = normalize
+
+#         self.swinViT = SwinTransformer(
+#             in_chans=in_channels,
+#             embed_dim=feature_size,
+#             window_size=window_size,
+#             patch_size=patch_size,
+#             depths=depths,
+#             num_heads=num_heads,
+#             mlp_ratio=4.0,
+#             qkv_bias=True,
+#             drop_rate=drop_rate,
+#             attn_drop_rate=attn_drop_rate,
+#             drop_path_rate=dropout_path_rate,
+#             norm_layer=nn.LayerNorm,
+#             use_checkpoint=use_checkpoint,
+#             spatial_dims=spatial_dims,
+#             downsample=look_up_option(downsample) if isinstance(
+#                 downsample, str) else downsample,
+#         )
+#         self.norm = nn.LayerNorm(feature_size*16)
+#         self.avgpool = nn.AdaptiveAvgPool3d([1, 1, 1])
+#         self.dim_reduction = nn.Conv3d(feature_size*16, out_channels, 1)
+
+#     def forward(self, x_in):
+#         hidden_states_out = self.swinViT(x_in, self.normalize)
+#         hidden_output = rearrange(
+#             hidden_states_out[4], "b c d h w -> b d h w c")
+#         nomalized_hidden_states_out = self.norm(hidden_output)
+#         nomalized_hidden_states_out = rearrange(
+#             nomalized_hidden_states_out, "b d h w c -> b c d h w")
+#         output = self.avgpool(nomalized_hidden_states_out)
+#         output = torch.flatten(self.dim_reduction(output), 1)
+#         # print(output.shape)
+#         return output
+
+
+# class CNNRadiologyModel(nn.Module):
+#     def __init__(
+#         self,
+#         in_channels: int,
+#         out_channels: int,
+#         feature_size: int = 24,
+#         spatial_dims: int = 3,
+#         dropout_rate: float = 0.7,
+#     ) -> None:
+#         """
+#         A simple 3D CNN-based feature extractor.
+#         Input requirement : [BxCxDxHxW]
+#         Args:
+#             in_channels: dimension of input channels.
+#             out_channels: dimension of output channels.
+#             feature_size: dimension of network feature size.
+#             spatial_dims: number of spatial dims (2 or 3).
+#             dropout_rate: dropout rate.
+#         """
+#         super().__init__()
+
+#         if spatial_dims != 3:
+#             raise ValueError("This implementation is designed for 3D inputs.")
+
+#         self.conv1 = nn.Conv3d(in_channels, feature_size, kernel_size=3, padding=1)
+#         self.conv2 = nn.Conv3d(feature_size, feature_size * 2, kernel_size=3, padding=1)
+#         self.conv3 = nn.Conv3d(feature_size * 2, feature_size * 4, kernel_size=3, padding=1)
+#         self.conv4 = nn.Conv3d(feature_size * 4, feature_size * 8, kernel_size=3, padding=1)
+
+#         self.bn1 = nn.BatchNorm3d(feature_size)
+#         self.bn2 = nn.BatchNorm3d(feature_size * 2)
+#         self.bn3 = nn.BatchNorm3d(feature_size * 4)
+#         self.bn4 = nn.BatchNorm3d(feature_size * 8)
+
+#         self.relu = nn.ReLU(inplace=True)
+#         self.dropout = nn.Dropout(dropout_rate)
+#         self.pool = nn.AdaptiveAvgPool3d(1)
+#         self.fc = nn.Conv3d(feature_size * 8, out_channels, kernel_size=1)
+
+#     def forward(self, x):
+#         x = self.relu(self.bn1(self.conv1(x)))
+#         x = self.relu(self.bn2(self.conv2(x)))
+#         x = self.relu(self.bn3(self.conv3(x)))
+#         x = self.relu(self.bn4(self.conv4(x)))
+#         x = self.dropout(x)
+#         x = self.pool(x)
+#         x = torch.flatten(self.fc(x), 1)
+#         #print(x.shape)
+#         return x
+
+
+class VisionTransformerRadiologyModel(nn.Module):
+    def __init__(
+        self,
+        patch_size: Union[Sequence[int], int],
+        in_channels: int,
+        out_channels: int,
+        emb_size: int,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        dropout_rate: float = 0.1,
+    ) -> None:
+        super().__init__()
+        self.vit = VisionTransformer(
+            in_channels=in_channels,
+            patch_size=patch_size,
+            emb_size=emb_size,
+            depth=depth,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            dropout_rate=dropout_rate
+        )
+        self.fc = nn.Linear(emb_size, out_channels)
+
+    def forward(self, x):
+        x = self.vit(x)
+        x = self.fc(x)
+        return x
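+
+
+# Wiring note: each extractor in Model maps a 4-channel CT volume to a feature
+# vector of length args.feature_size, which is the dim_in that FusionModelBi
+# expects, while the internal token width emb_size = args.feature_size // 2 must
+# remain divisible by num_heads (12) for nn.MultiheadAttention to accept it.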
+
+
+class Model(nn.Module):
+    def __init__(self, args):
+        super(Model, self).__init__()
+        # self.extractor_ct_tumor = SwinTransformerRadiologyModel(
+        #     patch_size=(1, 2, 2),
+        #     window_size=[[4, 4, 4], [4, 4, 4], [8, 8, 8], [4, 4, 4]],
+        #     in_channels=4,
+        #     out_channels=args.feature_size,
+        #     depths=(2, 2, 2, 2),
+        #     num_heads=(3, 6, 12, 24),
+        #     feature_size=int(args.feature_size/2),
+        #     norm_name="instance",
+        #     drop_rate=0.7,
+        #     attn_drop_rate=0.,
+        #     dropout_path_rate=0.2,
+        #     normalize=True,
+        #     use_checkpoint=False,
+        #     spatial_dims=3
+        # )
+        # self.extractor_ct_lymph = SwinTransformerRadiologyModel(
+        #     patch_size=(1, 2, 2),
+        #     window_size=[[4, 4, 4], [4, 4, 4], [8, 8, 8], [4, 4, 4]],
+        #     in_channels=4,
+        #     out_channels=args.feature_size,
+        #     depths=(2, 2, 2, 2),
+        #     num_heads=(3, 6, 12, 24),
+        #     feature_size=int(args.feature_size/2),
+        #     norm_name="instance",
+        #     drop_rate=0.7,
+        #     attn_drop_rate=0.,
+        #     dropout_path_rate=0.2,
+        #     normalize=True,
+        #     use_checkpoint=False,
+        #     spatial_dims=3
+        # )
+        # self.extractor_ct_tumor = CNNRadiologyModel(
+        #     in_channels=4,
+        #     out_channels=args.feature_size
+        # )
+        # self.extractor_ct_lymph = CNNRadiologyModel(
+        #     in_channels=4,
+        #     out_channels=args.feature_size
+        # )
+
+        self.extractor_ct_tumor = VisionTransformerRadiologyModel(
+            patch_size=(2, 2, 2),
+            in_channels=4,
+            out_channels=args.feature_size,
+            emb_size=int(args.feature_size/2),
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.0,
+            dropout_rate=0.1
+        )
+        self.extractor_ct_lymph = VisionTransformerRadiologyModel(
+            patch_size=(2, 2, 2),
+            in_channels=4,
+            out_channels=args.feature_size,
+            emb_size=int(args.feature_size/2),
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.0,
+            dropout_rate=0.1
+        )
+
+        self.fusion = FusionModelBi(args, args.feature_size, args.dim_out)
+
+    def forward(self, ct_tumor, ct_lymph):
+        features_tumor = self.extractor_ct_tumor(ct_tumor)
+        features_lymph = self.extractor_ct_lymph(ct_lymph)
+
+        output = self.fusion(features_tumor, features_lymph)
+        return output
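+
+
+if __name__ == "__main__":
+    # Hypothetical smoke test, not part of the training pipeline.  The attribute
+    # names below mirror what this file reads from `args` (fusion_type, act_type,
+    # task, hidden_units, dropout, feature_size, dim_out); the concrete values are
+    # placeholders and should be adjusted to the real argparse configuration, in
+    # particular act_type, whose accepted values depend on utils.define_act_layer.
+    from types import SimpleNamespace
+
+    args = SimpleNamespace(
+        fusion_type="concatenation",
+        act_type="none",        # placeholder value
+        task="multitask",
+        hidden_units=[64, 32],
+        dropout=0.7,
+        feature_size=48,        # emb_size = 48 // 2 = 24 is divisible by num_heads=12
+        dim_out=32,
+    )
+    model = Model(args).eval()  # eval() keeps BatchNorm1d/Dropout deterministic
+    ct_tumor = torch.randn(2, 4, 8, 16, 16)   # (B, C, D, H, W), divisible by patch_size
+    ct_lymph = torch.randn(2, 4, 8, 16, 16)
+    with torch.no_grad():
+        grade, hazard = model(ct_tumor, ct_lymph)
+    print(grade.shape, hazard.shape)          # both torch.Size([2])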