Diff of /fusion.py [000000] .. [2095ed]

Switch to side-by-side view

--- a
+++ b/fusion.py
@@ -0,0 +1,192 @@
+import torch
+import torch.nn as nn
+from utils import init_max_weights
+
+
+class BilinearFusion(nn.Module):
+    def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, dim1=32, dim2=32, scale_dim1=1, scale_dim2=1, mmhid=64, dropout_rate=0.25):
+        super(BilinearFusion, self).__init__()
+        self.skip = skip
+        self.use_bilinear = use_bilinear
+        self.gate1 = gate1
+        self.gate2 = gate2
+
+        dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
+        skip_dim = dim1+dim2+2 if skip else 0
+
+        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
+        self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
+        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
+        self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
+        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
+        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        init_max_weights(self)
+
+    def forward(self, vec1, vec2):
+        ### Gated Multimodal Units
+        if self.gate1:
+            h1 = self.linear_h1(vec1)
+            z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
+            o1 = self.linear_o1(nn.Sigmoid()(z1)*h1)
+        else:
+            o1 = self.linear_o1(vec1)
+
+        if self.gate2:
+            h2 = self.linear_h2(vec2)
+            z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
+            o2 = self.linear_o2(nn.Sigmoid()(z2)*h2)
+        else:
+            o2 = self.linear_o2(vec2)
+
+        ### Fusion
+        o1 = torch.cat((o1, torch.cuda.FloatTensor(o1.shape[0], 1).fill_(1)), 1)
+        o2 = torch.cat((o2, torch.cuda.FloatTensor(o2.shape[0], 1).fill_(1)), 1)
+        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE X 1024
+        out = self.post_fusion_dropout(o12)
+        out = self.encoder1(out)
+        if self.skip: out = torch.cat((out, o1, o2), 1)
+        out = self.encoder2(out)
+        return out
+
+
+class TrilinearFusion_A(nn.Module):
+    def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
+        super(TrilinearFusion_A, self).__init__()
+        self.skip = skip
+        self.use_bilinear = use_bilinear
+        self.gate1 = gate1
+        self.gate2 = gate2
+        self.gate3 = gate3
+
+        dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1//scale_dim1, dim2//scale_dim2, dim3//scale_dim3
+        skip_dim = dim1+dim2+dim3+3 if skip else 0
+
+        ### Path
+        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
+        self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim1))
+        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        ### Graph
+        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
+        self.linear_z2 = nn.Bilinear(dim2_og, dim3_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim2_og+dim3_og, dim2))
+        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        ### Omic
+        self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
+        self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim3))
+        self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.post_fusion_dropout = nn.Dropout(p=0.25)
+        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1)*(dim3+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        
+        init_max_weights(self)
+
+    def forward(self, vec1, vec2, vec3):
+        ### Gated Multimodal Units
+        if self.gate1:
+            h1 = self.linear_h1(vec1)
+            z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic
+            o1 = self.linear_o1(nn.Sigmoid()(z1)*h1)
+        else:
+            o1 = self.linear_o1(vec1)
+
+        if self.gate2:
+            h2 = self.linear_h2(vec2)
+            z2 = self.linear_z2(vec2, vec3) if self.use_bilinear else self.linear_z2(torch.cat((vec2, vec3), dim=1)) # Gate Graph with Omic
+            o2 = self.linear_o2(nn.Sigmoid()(z2)*h2)
+        else:
+            o2 = self.linear_o2(vec2)
+
+        if self.gate3:
+            h3 = self.linear_h3(vec3)
+            z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3(torch.cat((vec1, vec3), dim=1)) # Gate Omic With Path
+            o3 = self.linear_o3(nn.Sigmoid()(z3)*h3)
+        else:
+            o3 = self.linear_o3(vec3)
+
+        ### Fusion
+        o1 = torch.cat((o1, torch.cuda.FloatTensor(o1.shape[0], 1).fill_(1)), 1)
+        o2 = torch.cat((o2, torch.cuda.FloatTensor(o2.shape[0], 1).fill_(1)), 1)
+        o3 = torch.cat((o3, torch.cuda.FloatTensor(o3.shape[0], 1).fill_(1)), 1)
+        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1)
+        o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1)
+        out = self.post_fusion_dropout(o123)
+        out = self.encoder1(out)
+        if self.skip: out = torch.cat((out, o1, o2, o3), 1)
+        out = self.encoder2(out)
+        return out
+
+
+class TrilinearFusion_B(nn.Module):
+    def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
+        super(TrilinearFusion_B, self).__init__()
+        self.skip = skip
+        self.use_bilinear = use_bilinear
+        self.gate1 = gate1
+        self.gate2 = gate2
+        self.gate3 = gate3
+
+        dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1//scale_dim1, dim2//scale_dim2, dim3//scale_dim3
+        skip_dim = dim1+dim2+dim3+3 if skip else 0
+
+        ### Path
+        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
+        self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim1))
+        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        ### Graph
+        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
+        self.linear_z2 = nn.Bilinear(dim2_og, dim1_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim2_og+dim1_og, dim2))
+        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        ### Omic
+        self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
+        self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim3))
+        self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate))
+
+        self.post_fusion_dropout = nn.Dropout(p=0.25)
+        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1)*(dim3+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
+        
+        init_max_weights(self)
+
+    def forward(self, vec1, vec2, vec3):
+        ### Gated Multimodal Units
+        if self.gate1:
+            h1 = self.linear_h1(vec1)
+            z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic
+            o1 = self.linear_o1(nn.Sigmoid()(z1)*h1)
+        else:
+            o1 = self.linear_o1(vec1)
+
+        if self.gate2:
+            h2 = self.linear_h2(vec2)
+            z2 = self.linear_z2(vec2, vec1) if self.use_bilinear else self.linear_z2(torch.cat((vec2, vec1), dim=1)) # Gate Graph with Omic
+            o2 = self.linear_o2(nn.Sigmoid()(z2)*h2)
+        else:
+            o2 = self.linear_o2(vec2)
+
+        if self.gate3:
+            h3 = self.linear_h3(vec3)
+            z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3(torch.cat((vec1, vec3), dim=1)) # Gate Omic With Path
+            o3 = self.linear_o3(nn.Sigmoid()(z3)*h3)
+        else:
+            o3 = self.linear_o3(vec3)
+
+        ### Fusion
+        o1 = torch.cat((o1, torch.cuda.FloatTensor(o1.shape[0], 1).fill_(1)), 1)
+        o2 = torch.cat((o2, torch.cuda.FloatTensor(o2.shape[0], 1).fill_(1)), 1)
+        o3 = torch.cat((o3, torch.cuda.FloatTensor(o3.shape[0], 1).fill_(1)), 1)
+        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1)
+        o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1)
+        out = self.post_fusion_dropout(o123)
+        out = self.encoder1(out)
+        if self.skip: out = torch.cat((out, o1, o2, o3), 1)
+        out = self.encoder2(out)
+        return out
\ No newline at end of file