Diff of /models/model_porpoise.py [000000] .. [405115]


--- a
+++ b/models/model_porpoise.py
@@ -0,0 +1,394 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pdb
+import numpy as np
+from os.path import join
+from collections import OrderedDict
+
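+# Low-rank bilinear fusion of two gated unimodal embeddings: instead of the full
+# outer product, each (1-augmented) embedding is projected through rank-R factor
+# matrices and the projections are combined by an elementwise product.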
+class LRBilinearFusion(nn.Module):
+    def __init__(self, skip=0, use_bilinear=0, gate1=1, gate2=1, dim1=128, dim2=128, 
+                 scale_dim1=1, scale_dim2=1, dropout_rate=0.25,
+                rank=16, output_dim=4):
+        super(LRBilinearFusion, self).__init__()
+        self.skip = skip
+        self.use_bilinear = use_bilinear
+        self.gate1 = gate1
+        self.gate2 = gate2
+        self.rank = rank
+        self.output_dim = output_dim
+
+        dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
+        skip_dim = dim1_og+dim2_og if skip else 0
+
+        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
+        self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
+        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
+        self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
+        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        
+        
+        self.h1_factor = nn.Parameter(torch.Tensor(self.rank, dim1 + 1, output_dim))
+        self.h2_factor = nn.Parameter(torch.Tensor(self.rank, dim2 + 1, output_dim))
+        self.fusion_weights = nn.Parameter(torch.Tensor(1, self.rank))
+        self.fusion_bias = nn.Parameter(torch.Tensor(1, self.output_dim))
+        nn.init.xavier_normal_(self.h1_factor)
+        nn.init.xavier_normal_(self.h2_factor)
+        nn.init.xavier_normal_(self.fusion_weights)
+        self.fusion_bias.data.fill_(0)
+
+        #init_max_weights(self)
+
+    def forward(self, vec1, vec2):
+        ### Gated Multimodal Units
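+        # Gate each modality: a sigmoid gate z_i computed from both inputs
+        # (bilinear or concatenation-based) modulates the hidden state h_i.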
+        if self.gate1:
+            h1 = self.linear_h1(vec1)
+            z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
+            o1 = self.linear_o1(torch.sigmoid(z1) * h1)
+        else:
+            h1 = F.dropout(self.linear_h1(vec1), 0.25, training=self.training)
+            o1 = self.linear_o1(h1)
+
+        if self.gate2:
+            h2 = self.linear_h2(vec2)
+            z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
+            o2 = self.linear_o2(torch.sigmoid(z2) * h2)
+        else:
+            h2 = F.dropout(self.linear_h2(vec2), 0.25, training=self.training)
+            o2 = self.linear_o2(h2)
+
+        ### Fusion
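+        # Append a constant-1 feature to each gated embedding, project through the
+        # rank-R factors, multiply elementwise, and collapse the rank dimension
+        # with fusion_weights to obtain the output logits.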
+        ones = torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)
+        _o1 = torch.cat((ones, o1), dim=1)
+        _o2 = torch.cat((ones, o2), dim=1)
+        o1_fusion = torch.matmul(_o1, self.h1_factor)
+        o2_fusion = torch.matmul(_o2, self.h2_factor)
+        fusion_zy = o1_fusion * o2_fusion
+        output = torch.matmul(self.fusion_weights, fusion_zy.permute(1, 0, 2)).squeeze() + self.fusion_bias
+        output = output.view(-1, self.output_dim)
+        return output
+
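+# Bilinear fusion of two gated unimodal embeddings via their full outer
+# (Kronecker) product, followed by a two-layer MLP encoder (optionally with a
+# skip connection to the raw input vectors).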
+class BilinearFusion(nn.Module):
+    def __init__(self, skip=0, use_bilinear=0, gate1=1, gate2=1, dim1=128, dim2=128, scale_dim1=1, scale_dim2=1, mmhid=256, dropout_rate=0.25):
+        super(BilinearFusion, self).__init__()
+        self.skip = skip
+        self.use_bilinear = use_bilinear
+        self.gate1 = gate1
+        self.gate2 = gate2
+
+        dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
+        skip_dim = dim1_og+dim2_og if skip else 0
+
+        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
+        self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
+        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
+        self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
+        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
+        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1), 256), nn.ReLU())
+        self.encoder2 = nn.Sequential(nn.Linear(256+skip_dim, mmhid), nn.ReLU())
+        #init_max_weights(self)
+
+    def forward(self, vec1, vec2):
+        ### Gated Multimodal Units
+        if self.gate1:
+            h1 = self.linear_h1(vec1)
+            z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
+            o1 = self.linear_o1(torch.sigmoid(z1) * h1)
+        else:
+            h1 = self.linear_h1(vec1)
+            o1 = self.linear_o1(h1)
+
+        if self.gate2:
+            h2 = self.linear_h2(vec2)
+            z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
+            o2 = self.linear_o2(torch.sigmoid(z2) * h2)
+        else:
+            h2 = self.linear_h2(vec2)
+            o2 = self.linear_o2(h2)
+
+        ### Fusion
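+        # Append a constant-1 entry to each embedding so the outer product keeps
+        # the first-order (unimodal) terms as well, then flatten per sample.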
+        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
+        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
+        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # B x (dim1+1)*(dim2+1)
+        out = self.post_fusion_dropout(o12)
+        out = self.encoder1(out)
+        if self.skip: out = torch.cat((out, vec1, vec2), 1)
+        out = self.encoder2(out)
+        return out
+
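+# Genomic encoder building blocks: SNN_Block pairs ELU with AlphaDropout
+# (self-normalizing-style), MLP_Block pairs ReLU with standard Dropout.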
+def SNN_Block(dim1, dim2, dropout=0.25):
+    return nn.Sequential(
+            nn.Linear(dim1, dim2),
+            nn.ELU(),
+            nn.AlphaDropout(p=dropout, inplace=False))
+
+
+def MLP_Block(dim1, dim2, dropout=0.25):
+    return nn.Sequential(
+            nn.Linear(dim1, dim2),
+            nn.ReLU(),
+            nn.Dropout(p=dropout, inplace=False))
+
+
+"""
+Attention Network without Gating (2 fc layers)
+args:
+    L: input feature dimension
+    D: hidden layer dimension
+    dropout: whether to use dropout (p = 0.25)
+    n_classes: number of classes (experimental usage for multiclass MIL)
+"""
+class Attn_Net(nn.Module):
+
+    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
+        super(Attn_Net, self).__init__()
+        self.module = [
+            nn.Linear(L, D),
+            nn.Tanh()]
+
+        if dropout:
+            self.module.append(nn.Dropout(0.25))
+
+        self.module.append(nn.Linear(D, n_classes))
+        
+        self.module = nn.Sequential(*self.module)
+    
+    def forward(self, x):
+        return self.module(x), x # N x n_classes
+
+"""
+Attention Network with Sigmoid Gating (3 fc layers)
+args:
+    L: input feature dimension
+    D: hidden layer dimension
+    dropout: whether to use dropout (p = 0.25)
+    n_classes: number of classes (experimental usage for multiclass MIL)
+"""
+class Attn_Net_Gated(nn.Module):
+
+    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
+        super(Attn_Net_Gated, self).__init__()
+        self.attention_a = [
+            nn.Linear(L, D),
+            nn.Tanh()]
+        
+        self.attention_b = [nn.Linear(L, D),
+                            nn.Sigmoid()]
+        if dropout:
+            self.attention_a.append(nn.Dropout(0.25))
+            self.attention_b.append(nn.Dropout(0.25))
+
+        self.attention_a = nn.Sequential(*self.attention_a)
+        self.attention_b = nn.Sequential(*self.attention_b)
+        
+        self.attention_c = nn.Linear(D, n_classes)
+
+    def forward(self, x):
+        a = self.attention_a(x)
+        b = self.attention_b(x)
+        A = a.mul(b)
+        A = self.attention_c(A)  # N x n_classes
+        return A, x
+
+
+def initialize_weights(module):
+    for m in module.modules():
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_normal_(m.weight)
+            m.bias.data.zero_()
+        
+        elif isinstance(m, nn.BatchNorm1d):
+            nn.init.constant_(m.weight, 1)
+            nn.init.constant_(m.bias, 0)
+
+
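+# Attention-based MIL over a bag of patch embeddings: gated attention scores
+# pool the projected patch features (N x 512) into a single slide embedding,
+# which a linear layer classifies.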
+class PorpoiseAMIL(nn.Module):
+    def __init__(self, size_arg = "small", n_classes=4):
+        super(PorpoiseAMIL, self).__init__()
+        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        size = self.size_dict[size_arg]
+        
+        fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(0.25)]
+        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=0.25, n_classes=1)
+        fc.append(attention_net)
+        self.attention_net = nn.Sequential(*fc)
+        
+        self.classifier = nn.Linear(size[1], n_classes)
+        initialize_weights(self)
+                
+                
+    def relocate(self):
+        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() > 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
+        else:
+            self.attention_net = self.attention_net.to(device)
+
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        h = kwargs['x_path']
+
+        A, h = self.attention_net(h)  
+        A = torch.transpose(A, 1, 0)
+
+        if 'attention_only' in kwargs.keys():
+            if kwargs['attention_only']:
+                return A
+
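+        # Softmax over the patch axis, then attention-weighted pooling of the
+        # N x 512 patch embeddings into a 1 x 512 slide-level embedding.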
+        A_raw = A 
+        A = F.softmax(A, dim=1) 
+        M = torch.mm(A, h) 
+        h  = self.classifier(M)
+        return h
+
+    def get_slide_features(self, **kwargs):
+        h = kwargs['x_path']
+
+        A, h = self.attention_net(h)  
+        A = torch.transpose(A, 1, 0)
+
+        if 'attention_only' in kwargs.keys():
+            if kwargs['attention_only']:
+                return A
+
+        A_raw = A 
+        A = F.softmax(A, dim=1) 
+        M = torch.mm(A, h) 
+        return M
+
+
+### MMF (in the PORPOISE Paper)
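+# Combines an attention-MIL branch over WSI patch features with an SNN/MLP
+# branch over genomic input; the two 256-d embeddings are fused (concatenation,
+# bilinear, or low-rank bilinear) and mapped to n_classes logits.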
+class PorpoiseMMF(nn.Module):
+    def __init__(self, 
+        omic_input_dim,
+        path_input_dim=1024, 
+        fusion='bilinear', 
+        dropout=0.25,
+        n_classes=4, 
+        scale_dim1=8, 
+        scale_dim2=8, 
+        gate_path=1, 
+        gate_omic=1, 
+        skip=True, 
+        dropinput=0.10,
+        use_mlp=False,
+        size_arg = "small",
+        ):
+        super(PorpoiseMMF, self).__init__()
+        self.fusion = fusion
+        self.size_dict_path = {"small": [path_input_dim, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+        self.n_classes = n_classes
+
+        ### Deep Sets Architecture Construction
+        size = self.size_dict_path[size_arg]
+        if dropinput:
+            fc = [nn.Dropout(dropinput), nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+        else:
+            fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
+        fc.append(attention_net)
+        self.attention_net = nn.Sequential(*fc)
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Constructing Genomic SNN
+        if self.fusion is not None:
+            if use_mlp:
+                Block = MLP_Block
+            else:
+                Block = SNN_Block
+
+            hidden = self.size_dict_omic['small']
+            fc_omic = [Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+        
+            if self.fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=scale_dim1, gate1=gate_path, scale_dim2=scale_dim2, gate2=gate_omic, skip=skip, mmhid=256)
+            elif self.fusion == 'lrb':
+                self.mm = LRBilinearFusion(dim1=256, dim2=256, scale_dim1=scale_dim1, gate1=gate_path, scale_dim2=scale_dim2, gate2=gate_omic)
+            else:
+                self.mm = None
+
+        self.classifier_mm = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.rho = self.rho.to(device)
+        self.classifier_mm = self.classifier_mm.to(device)
+
+    def forward(self, **kwargs):
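+        # WSI branch: gated-attention pooling over the patch embeddings,
+        # followed by rho projecting the pooled vector to 256-d.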
+        x_path = kwargs['x_path']
+        A, h_path = self.attention_net(x_path)  
+        A = torch.transpose(A, 1, 0)
+        A_raw = A 
+        A = F.softmax(A, dim=1) 
+        h_path = torch.mm(A, h_path)
+        h_path = self.rho(h_path)
+
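+        # Genomic branch, then the selected fusion operator; the 'lrb' branch
+        # already returns class logits, so it skips classifier_mm.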
+        x_omic = kwargs['x_omic']
+        h_omic = self.fc_omic(x_omic)
+        if self.fusion == 'bilinear':
+            h_mm = self.mm(h_path, h_omic)
+        elif self.fusion == 'concat':
+            h_mm = self.mm(torch.cat([h_path, h_omic], axis=1))
+        elif self.fusion == 'lrb':
+            h_mm  = self.mm(h_path, h_omic) # logits needs to be a [1 x 4] vector 
+            return h_mm
+
+        h_mm = self.classifier_mm(h_mm)  # logits: [B x n_classes]
+        assert len(h_mm.shape) == 2 and h_mm.shape[1] == self.n_classes
+
+        return h_mm
+
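+    # Batched forward pass used for Captum attribution: logits are turned into
+    # discrete hazards (sigmoid), survival probabilities (cumprod of 1 - hazard),
+    # and a scalar risk score (negative sum of survival) per sample.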
+    def captum(self, h, X):
+        A, h = self.attention_net(h)  
+        A = A.squeeze(dim=2)
+
+        A = F.softmax(A, dim=1) 
+        M = torch.bmm(A.unsqueeze(dim=1), h).squeeze(dim=1) #M = torch.mm(A, h)
+        M = self.rho(M)
+        O = self.fc_omic(X)
+
+        if self.fusion == 'bilinear':
+            MM = self.mm(M, O)
+        elif self.fusion == 'concat':
+            MM = self.mm(torch.cat([M, O], axis=1))
+            
+        logits = self.classifier_mm(MM)
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+
+        risk = -torch.sum(S, dim=1)
+        return risk
\ No newline at end of file