import torch
import torch.nn as nn
from utils import init_max_weights
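
# NOTE: init_max_weights lives in the accompanying utils module and is not shown here.
# It is assumed to re-initialize the weights of every nn.Linear submodule (e.g. with a
# small, zero-mean normal distribution) so that the high-dimensional Kronecker-product
# fusion below starts out well-scaled; the exact scheme is defined in utils.py.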
class BilinearFusion(nn.Module):
def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, dim1=32, dim2=32, scale_dim1=1, scale_dim2=1, mmhid=64, dropout_rate=0.25):
super(BilinearFusion, self).__init__()
self.skip = skip
self.use_bilinear = use_bilinear
self.gate1 = gate1
self.gate2 = gate2
        # keep the original input dims; optionally down-scale the hidden dims
        dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
        skip_dim = dim1+dim2+2 if skip else 0  # size of the two 1-padded gated vectors appended by the skip connection
self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
init_max_weights(self)
def forward(self, vec1, vec2):
### Gated Multimodal Units
if self.gate1:
h1 = self.linear_h1(vec1)
z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
            o1 = self.linear_o1(torch.sigmoid(z1)*h1)
else:
o1 = self.linear_o1(vec1)
if self.gate2:
h2 = self.linear_h2(vec2)
z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
            o2 = self.linear_o2(torch.sigmoid(z2)*h2)
else:
o2 = self.linear_o2(vec2)
        ### Fusion: append a constant 1 to each gated vector, then take their outer product
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE x (dim1+1)*(dim2+1)
out = self.post_fusion_dropout(o12)
out = self.encoder1(out)
if self.skip: out = torch.cat((out, o1, o2), 1)
out = self.encoder2(out)
return out
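

# Minimal usage sketch for BilinearFusion (illustrative only; the batch size and input
# dimensions below are assumptions, not values fixed by the module). The constant-1
# padding uses the inputs' device, so this runs on CPU or GPU:
#
#   fusion = BilinearFusion(dim1=32, dim2=32, mmhid=64)
#   vec1, vec2 = torch.randn(8, 32), torch.randn(8, 32)
#   out = fusion(vec1, vec2)  # -> torch.Size([8, 64])
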
class TrilinearFusion_A(nn.Module):
def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
super(TrilinearFusion_A, self).__init__()
self.skip = skip
self.use_bilinear = use_bilinear
self.gate1 = gate1
self.gate2 = gate2
self.gate3 = gate3
dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1//scale_dim1, dim2//scale_dim2, dim3//scale_dim3
skip_dim = dim1+dim2+dim3+3 if skip else 0
### Path
self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim1))
self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
### Graph
self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
self.linear_z2 = nn.Bilinear(dim2_og, dim3_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim2_og+dim3_og, dim2))
self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
### Omic
self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim3))
self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate))
        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1)*(dim3+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
init_max_weights(self)
def forward(self, vec1, vec2, vec3):
### Gated Multimodal Units
if self.gate1:
h1 = self.linear_h1(vec1)
z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic
            o1 = self.linear_o1(torch.sigmoid(z1)*h1)
else:
o1 = self.linear_o1(vec1)
if self.gate2:
h2 = self.linear_h2(vec2)
z2 = self.linear_z2(vec2, vec3) if self.use_bilinear else self.linear_z2(torch.cat((vec2, vec3), dim=1)) # Gate Graph with Omic
            o2 = self.linear_o2(torch.sigmoid(z2)*h2)
else:
o2 = self.linear_o2(vec2)
if self.gate3:
h3 = self.linear_h3(vec3)
            z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3(torch.cat((vec1, vec3), dim=1)) # Gate Omic with Path
            o3 = self.linear_o3(torch.sigmoid(z3)*h3)
else:
o3 = self.linear_o3(vec3)
        ### Fusion: append a constant 1 to each gated vector, then take the pairwise and triple outer products
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
        o3 = torch.cat((o3, torch.ones(o3.shape[0], 1, device=o3.device, dtype=o3.dtype)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE x (dim1+1)*(dim2+1)
        o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE x (dim1+1)*(dim2+1)*(dim3+1)
out = self.post_fusion_dropout(o123)
out = self.encoder1(out)
if self.skip: out = torch.cat((out, o1, o2, o3), 1)
out = self.encoder2(out)
return out
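

# Minimal usage sketch for TrilinearFusion_A (illustrative only; batch size and input
# dimensions are assumptions). Path (vec1) and graph (vec2) are gated with omic (vec3),
# and omic is gated with path:
#
#   fusion = TrilinearFusion_A(dim1=32, dim2=32, dim3=32, mmhid=96)
#   vec1, vec2, vec3 = torch.randn(8, 32), torch.randn(8, 32), torch.randn(8, 32)
#   out = fusion(vec1, vec2, vec3)  # -> torch.Size([8, 96])
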
class TrilinearFusion_B(nn.Module):
def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
super(TrilinearFusion_B, self).__init__()
self.skip = skip
self.use_bilinear = use_bilinear
self.gate1 = gate1
self.gate2 = gate2
self.gate3 = gate3
dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1//scale_dim1, dim2//scale_dim2, dim3//scale_dim3
skip_dim = dim1+dim2+dim3+3 if skip else 0
### Path
self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim1))
self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
### Graph
self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
self.linear_z2 = nn.Bilinear(dim2_og, dim1_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim2_og+dim1_og, dim2))
self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
### Omic
self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim3))
self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate))
        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1)*(dim3+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
init_max_weights(self)
def forward(self, vec1, vec2, vec3):
### Gated Multimodal Units
if self.gate1:
h1 = self.linear_h1(vec1)
z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic
            o1 = self.linear_o1(torch.sigmoid(z1)*h1)
else:
o1 = self.linear_o1(vec1)
if self.gate2:
h2 = self.linear_h2(vec2)
            z2 = self.linear_z2(vec2, vec1) if self.use_bilinear else self.linear_z2(torch.cat((vec2, vec1), dim=1)) # Gate Graph with Path
            o2 = self.linear_o2(torch.sigmoid(z2)*h2)
else:
o2 = self.linear_o2(vec2)
if self.gate3:
h3 = self.linear_h3(vec3)
            z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3(torch.cat((vec1, vec3), dim=1)) # Gate Omic with Path
            o3 = self.linear_o3(torch.sigmoid(z3)*h3)
else:
o3 = self.linear_o3(vec3)
        ### Fusion: append a constant 1 to each gated vector, then take the pairwise and triple outer products
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
        o3 = torch.cat((o3, torch.ones(o3.shape[0], 1, device=o3.device, dtype=o3.dtype)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE x (dim1+1)*(dim2+1)
        o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE x (dim1+1)*(dim2+1)*(dim3+1)
out = self.post_fusion_dropout(o123)
out = self.encoder1(out)
if self.skip: out = torch.cat((out, o1, o2, o3), 1)
out = self.encoder2(out)
return out
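

# Minimal usage sketch for TrilinearFusion_B (illustrative only; batch size and input
# dimensions are assumptions). B differs from A only in the graph branch, which is
# gated with path (vec1) rather than omic (vec3):
#
#   fusion = TrilinearFusion_B(dim1=32, dim2=32, dim3=32, mmhid=96)
#   vec1, vec2, vec3 = torch.randn(8, 32), torch.randn(8, 32), torch.randn(8, 32)
#   out = fusion(vec1, vec2, vec3)  # -> torch.Size([8, 96])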