Diff of /models/model_set_mil.py [000000] .. [405115]

Switch to side-by-side view

--- a
+++ b/models/model_set_mil.py
@@ -0,0 +1,297 @@
+from collections import OrderedDict
+from os.path import join
+import pdb
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.model_utils import *
+
+
+
+################################
+### Deep Sets Implementation ###
+################################
+class MIL_Sum_FC_surv(nn.Module):
+    def __init__(self, omic_input_dim=None, fusion=None, size_arg = "small", dropout=0.25, n_classes=4):
+        r"""
+        Deep Sets Implementation.
+
+        Args:
+            omic_input_dim (int): Dimension size of genomic features.
+            fusion (str): Fusion method (Choices: concat, bilinear, or None)
+            size_arg (str): Size of NN architecture (Choices: small or large)
+            dropout (float): Dropout rate
+            n_classes (int): Output shape of NN
+        """
+        super(MIL_Sum_FC_surv, self).__init__()
+        self.fusion = fusion
+        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+
+        ### Deep Sets Architecture Construction
+        size = self.size_dict_path[size_arg]
+        self.phi = nn.Sequential(*[nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)])
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Constructing Genomic SNN
+        if self.fusion != None:
+            hidden = [256, 256]
+            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+        
+            if self.fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
+            else:
+                self.mm = None
+
+        self.classifier = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.phi = nn.DataParallel(self.phi, device_ids=device_ids).to('cuda:0')
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.rho = self.rho.to(device)
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        x_path = kwargs['x_path']
+
+        h_path = self.phi(x_path).sum(axis=0)
+        h_path = self.rho(h_path)
+
+        if self.fusion is not None:
+            x_omic = kwargs['x_omic']
+            h_omic = self.fc_omic(x_omic).squeeze(dim=0)
+            if self.fusion == 'bilinear':
+                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
+            elif self.fusion == 'concat':
+                h = self.mm(torch.cat([h_path, h_omic], axis=0))
+        else:
+            h = h_path # [256] vector
+
+        logits  = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x 4] vector 
+        Y_hat = torch.topk(logits, 1, dim = 1)[1]
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+        
+        return hazards, S, Y_hat, None, None
+
+
+
+################################
+# Attention MIL Implementation #
+################################
+class MIL_Attention_FC_surv(nn.Module):
+    def __init__(self, omic_input_dim=None, fusion=None, size_arg = "small", dropout=0.25, n_classes=4):
+        r"""
+        Attention MIL Implementation
+
+        Args:
+            omic_input_dim (int): Dimension size of genomic features.
+            fusion (str): Fusion method (Choices: concat, bilinear, or None)
+            size_arg (str): Size of NN architecture (Choices: small or large)
+            dropout (float): Dropout rate
+            n_classes (int): Output shape of NN
+        """
+        super(MIL_Attention_FC_surv, self).__init__()
+        self.fusion = fusion
+        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+
+        ### Deep Sets Architecture Construction
+        size = self.size_dict_path[size_arg]
+        fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
+        fc.append(attention_net)
+        self.attention_net = nn.Sequential(*fc)
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Constructing Genomic SNN
+        if self.fusion is not None:
+            hidden = [256, 256]
+            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+        
+            if self.fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
+            else:
+                self.mm = None
+
+        self.classifier = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.rho = self.rho.to(device)
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        x_path = kwargs['x_path']
+
+        A, h_path = self.attention_net(x_path)  
+        A = torch.transpose(A, 1, 0)
+        A_raw = A 
+        A = F.softmax(A, dim=1) 
+        h_path = torch.mm(A, h_path)
+        h_path = self.rho(h_path).squeeze()
+
+        if self.fusion is not None:
+            x_omic = kwargs['x_omic']
+            h_omic = self.fc_omic(x_omic)
+            if self.fusion == 'bilinear':
+                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
+            elif self.fusion == 'concat':
+                h = self.mm(torch.cat([h_path, h_omic], axis=0))
+        else:
+            h = h_path # [256] vector
+
+        logits  = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x 4] vector 
+        Y_hat = torch.topk(logits, 1, dim = 1)[1]
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+        
+        return hazards, S, Y_hat, None, None
+
+
+
+######################################
+# Deep Attention MISL Implementation #
+######################################
+class MIL_Cluster_FC_surv(nn.Module):
+    def __init__(self, omic_input_dim=None, fusion=None, num_clusters=10, size_arg = "small", dropout=0.25, n_classes=4):
+        r"""
+        Attention MIL Implementation
+
+        Args:
+            omic_input_dim (int): Dimension size of genomic features.
+            fusion (str): Fusion method (Choices: concat, bilinear, or None)
+            size_arg (str): Size of NN architecture (Choices: small or large)
+            dropout (float): Dropout rate
+            n_classes (int): Output shape of NN
+        """
+        super(MIL_Cluster_FC_surv, self).__init__()
+        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+        self.num_clusters = num_clusters
+        self.fusion = fusion
+        
+        ### FC Cluster layers + Pooling
+        size = self.size_dict_path[size_arg]
+        phis = []
+        for phenotype_i in range(num_clusters):
+            phi = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout),
+                   nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+            phis.append(nn.Sequential(*phi))
+        self.phis = nn.ModuleList(phis)
+        self.pool1d = nn.AdaptiveAvgPool1d(1)
+        
+        ### WSI Attention MIL Construction
+        fc = [nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
+        fc.append(attention_net)
+        self.attention_net = nn.Sequential(*fc)
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Genomic SNN Construction + Multimodal Fusion
+        if fusion is not None:
+            hidden = self.size_dict_omic['small']
+            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+
+            if fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(size[2]*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
+            else:
+                self.mm = None
+
+        self.classifier = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
+        else:
+            self.attention_net = self.attention_net.to(device)
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.phis = self.phis.to(device)
+        self.pool1d = self.pool1d.to(device)
+        self.rho = self.rho.to(device)
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        x_path = kwargs['x_path']
+        cluster_id = kwargs['cluster_id'].detach().cpu().numpy()
+
+        ### FC Cluster layers + Pooling
+        h_cluster = []
+        for i in range(self.num_clusters):
+            h_cluster_i = self.phis[i](x_path[cluster_id==i])
+            if h_cluster_i.shape[0] == 0:
+                h_cluster_i = torch.zeros((1,512)).to(torch.device('cuda'))
+            h_cluster.append(self.pool1d(h_cluster_i.T.unsqueeze(0)).squeeze(2))
+        h_cluster = torch.stack(h_cluster, dim=1).squeeze(0)
+
+        ### Attention MIL
+        A, h_path = self.attention_net(h_cluster)  
+        A = torch.transpose(A, 1, 0)
+        A_raw = A 
+        A = F.softmax(A, dim=1) 
+        h_path = torch.mm(A, h_path)
+        h_path = self.rho(h_path).squeeze()
+
+        ### Attention MIL + Genomic Fusion
+        if self.fusion is not None:
+            x_omic = kwargs['x_omic']
+            h_omic = self.fc_omic(x_omic)
+            if self.fusion == 'bilinear':
+                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
+            elif self.fusion == 'concat':
+                h = self.mm(torch.cat([h_path, h_omic], axis=0))
+        else:
+            h = h_path
+
+        logits  = self.classifier(h).unsqueeze(0)
+        Y_hat = torch.topk(logits, 1, dim = 1)[1]
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+        
+        return hazards, S, Y_hat, None, None
\ No newline at end of file