import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
import numpy as np
from os.path import join
from collections import OrderedDict
class LRBilinearFusion(nn.Module):
def __init__(self, skip=0, use_bilinear=0, gate1=1, gate2=1, dim1=128, dim2=128,
scale_dim1=1, scale_dim2=1, dropout_rate=0.25,
rank=16, output_dim=4):
super(LRBilinearFusion, self).__init__()
self.skip = skip
self.use_bilinear = use_bilinear
self.gate1 = gate1
self.gate2 = gate2
self.rank = rank
self.output_dim = output_dim
dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
skip_dim = dim1_og+dim2_og if skip else 0
self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
        # Low-rank factors for the bilinear interaction; the "+ 1" accounts for the
        # constant-1 column appended to each modality in forward().
        self.h1_factor = nn.Parameter(torch.Tensor(self.rank, dim1 + 1, output_dim))
        self.h2_factor = nn.Parameter(torch.Tensor(self.rank, dim2 + 1, output_dim))
        self.fusion_weights = nn.Parameter(torch.Tensor(1, self.rank))
        self.fusion_bias = nn.Parameter(torch.Tensor(1, self.output_dim))
        nn.init.xavier_normal_(self.h1_factor)
        nn.init.xavier_normal_(self.h2_factor)
        nn.init.xavier_normal_(self.fusion_weights)
        self.fusion_bias.data.fill_(0)
#init_max_weights(self)
def forward(self, vec1, vec2):
### Gated Multimodal Units
if self.gate1:
h1 = self.linear_h1(vec1)
z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
o1 = self.linear_o1(nn.Sigmoid()(z1)*h1)
else:
            h1 = F.dropout(self.linear_h1(vec1), p=0.25, training=self.training)
o1 = self.linear_o1(h1)
if self.gate2:
h2 = self.linear_h2(vec2)
z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
o2 = self.linear_o2(nn.Sigmoid()(z2)*h2)
else:
            h2 = F.dropout(self.linear_h2(vec2), p=0.25, training=self.training)
o2 = self.linear_o2(h2)
### Fusion
        # Append a constant-1 column to each modality so the low-rank factors can also
        # capture unimodal terms; build it on the inputs' device instead of assuming
        # CUDA and a batch size of 1.
        _o1 = torch.cat((torch.ones(o1.shape[0], 1, dtype=o1.dtype, device=o1.device), o1), dim=1)
        _o2 = torch.cat((torch.ones(o2.shape[0], 1, dtype=o2.dtype, device=o2.device), o2), dim=1)
o1_fusion = torch.matmul(_o1, self.h1_factor)
o2_fusion = torch.matmul(_o2, self.h2_factor)
fusion_zy = o1_fusion * o2_fusion
output = torch.matmul(self.fusion_weights, fusion_zy.permute(1, 0, 2)).squeeze() + self.fusion_bias
output = output.view(-1, self.output_dim)
return output
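
# Minimal usage sketch for LRBilinearFusion (shapes below are illustrative
# assumptions, not values taken from the training pipeline):
#
#     fusion = LRBilinearFusion(dim1=128, dim2=128, rank=16, output_dim=4)
#     logits = fusion(torch.randn(2, 128), torch.randn(2, 128))  # -> [2, 4]
#
# Each modality is projected, (optionally) gated, augmented with a constant-1
# column, and multiplied against `rank` low-rank factors; the per-rank products
# are mixed by fusion_weights, avoiding a full (dim1+1) x (dim2+1) outer product.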
class BilinearFusion(nn.Module):
def __init__(self, skip=0, use_bilinear=0, gate1=1, gate2=1, dim1=128, dim2=128, scale_dim1=1, scale_dim2=1, mmhid=256, dropout_rate=0.25):
super(BilinearFusion, self).__init__()
self.skip = skip
self.use_bilinear = use_bilinear
self.gate1 = gate1
self.gate2 = gate2
dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
skip_dim = dim1_og+dim2_og if skip else 0
self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))
self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1), 256), nn.ReLU())
self.encoder2 = nn.Sequential(nn.Linear(256+skip_dim, mmhid), nn.ReLU())
#init_max_weights(self)
def forward(self, vec1, vec2):
### Gated Multimodal Units
if self.gate1:
h1 = self.linear_h1(vec1)
z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
o1 = self.linear_o1(nn.Sigmoid()(z1)*h1)
else:
h1 = self.linear_h1(vec1)
o1 = self.linear_o1(h1)
if self.gate2:
h2 = self.linear_h2(vec2)
z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
o2 = self.linear_o2(nn.Sigmoid()(z2)*h2)
else:
h2 = self.linear_h2(vec2)
o2 = self.linear_o2(h2)
### Fusion
        # Append a constant-1 column to each modality and take the outer product
        # (bilinear pooling); build the ones on the inputs' device instead of assuming CUDA.
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, dtype=o1.dtype, device=o1.device)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, dtype=o2.dtype, device=o2.device)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1)  # B x (dim1+1)*(dim2+1)
out = self.post_fusion_dropout(o12)
out = self.encoder1(out)
if self.skip: out = torch.cat((out, vec1, vec2), 1)
out = self.encoder2(out)
return out
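
# Minimal usage sketch for BilinearFusion (illustrative, assumed shapes):
#
#     fusion = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, skip=1, mmhid=256)
#     h_mm = fusion(torch.randn(2, 256), torch.randn(2, 256))  # -> [2, 256]
#
# With scale_dim=8 each 256-d input is reduced to 32-d, a constant-1 column is
# appended, and the flattened 33 x 33 outer product (1089-d) is encoded back to
# mmhid dimensions; skip=1 concatenates the raw inputs before the second encoder.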
def SNN_Block(dim1, dim2, dropout=0.25):
return nn.Sequential(
nn.Linear(dim1, dim2),
nn.ELU(),
nn.AlphaDropout(p=dropout, inplace=False))
def MLP_Block(dim1, dim2, dropout=0.25):
return nn.Sequential(
nn.Linear(dim1, dim2),
nn.ReLU(),
nn.Dropout(p=dropout, inplace=False))
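
# Sketch of how these block factories compose into a small genomic encoder (the
# 80-d input below is an assumed, illustrative dimension):
#
#     fc_omic = nn.Sequential(SNN_Block(80, 256), SNN_Block(256, 256))
#     h_omic = fc_omic(torch.randn(2, 80))  # -> [2, 256]
#
# SNN_Block pairs ELU with AlphaDropout, while MLP_Block is the plain
# ReLU + Dropout alternative selected via use_mlp in PorpoiseMMF.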
"""
Attention Network without Gating (2 fc layers)
args:
L: input feature dimension
D: hidden layer dimension
dropout: whether to use dropout (p = 0.25)
n_classes: number of classes (experimental usage for multiclass MIL)
"""
class Attn_Net(nn.Module):
def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
super(Attn_Net, self).__init__()
self.module = [
nn.Linear(L, D),
nn.Tanh()]
if dropout:
self.module.append(nn.Dropout(0.25))
self.module.append(nn.Linear(D, n_classes))
self.module = nn.Sequential(*self.module)
def forward(self, x):
return self.module(x), x # N x n_classes
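
# Usage sketch (assumed bag of N = 100 instances with 1024-d features):
#
#     attn = Attn_Net(L=1024, D=256, n_classes=1)
#     A, h = attn(torch.randn(100, 1024))  # A: [100, 1] attention scores, h: [100, 1024]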
"""
Attention Network with Sigmoid Gating (3 fc layers)
args:
L: input feature dimension
D: hidden layer dimension
dropout: whether to use dropout (p = 0.25)
n_classes: number of classes (experimental usage for multiclass MIL)
"""
class Attn_Net_Gated(nn.Module):
def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
super(Attn_Net_Gated, self).__init__()
self.attention_a = [
nn.Linear(L, D),
nn.Tanh()]
self.attention_b = [nn.Linear(L, D),
nn.Sigmoid()]
if dropout:
self.attention_a.append(nn.Dropout(0.25))
self.attention_b.append(nn.Dropout(0.25))
self.attention_a = nn.Sequential(*self.attention_a)
self.attention_b = nn.Sequential(*self.attention_b)
self.attention_c = nn.Linear(D, n_classes)
def forward(self, x):
a = self.attention_a(x)
b = self.attention_b(x)
A = a.mul(b)
A = self.attention_c(A) # N x n_classes
return A, x
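
# Usage sketch (same assumed shapes as above); the tanh and sigmoid branches are
# multiplied element-wise before the final projection to attention scores:
#
#     attn = Attn_Net_Gated(L=1024, D=256, dropout=True, n_classes=1)
#     A, h = attn(torch.randn(100, 1024))  # A: [100, 1], h: [100, 1024]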
"""
"""
def initialize_weights(module):
for m in module.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class PorpoiseAMIL(nn.Module):
def __init__(self, size_arg = "small", n_classes=4):
super(PorpoiseAMIL, self).__init__()
self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
size = self.size_dict[size_arg]
fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(0.25)]
attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=0.25, n_classes=1)
fc.append(attention_net)
self.attention_net = nn.Sequential(*fc)
self.classifier = nn.Linear(size[1], n_classes)
initialize_weights(self)
def relocate(self):
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
device_ids = list(range(torch.cuda.device_count()))
self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
else:
self.attention_net = self.attention_net.to(device)
self.classifier = self.classifier.to(device)
def forward(self, **kwargs):
h = kwargs['x_path']
A, h = self.attention_net(h)
A = torch.transpose(A, 1, 0)
        if kwargs.get('attention_only', False):
            return A
A_raw = A
A = F.softmax(A, dim=1)
M = torch.mm(A, h)
h = self.classifier(M)
return h
def get_slide_features(self, **kwargs):
h = kwargs['x_path']
A, h = self.attention_net(h)
A = torch.transpose(A, 1, 0)
        if kwargs.get('attention_only', False):
            return A
A_raw = A
A = F.softmax(A, dim=1)
M = torch.mm(A, h)
return M
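
# Usage sketch for the unimodal AMIL baseline (the 100-patch bag is an assumed,
# illustrative size):
#
#     model = PorpoiseAMIL(size_arg="small", n_classes=4)
#     logits = model(x_path=torch.randn(100, 1024))  # -> [1, 4]
#
# The gated attention network scores every patch embedding, the softmax-normalized
# scores pool the bag into one slide-level vector, and the linear classifier maps
# that vector to n_classes logits.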
### MMF (in the PORPOISE Paper)
class PorpoiseMMF(nn.Module):
def __init__(self,
omic_input_dim,
path_input_dim=1024,
fusion='bilinear',
dropout=0.25,
n_classes=4,
scale_dim1=8,
scale_dim2=8,
gate_path=1,
gate_omic=1,
skip=True,
dropinput=0.10,
use_mlp=False,
size_arg = "small",
):
super(PorpoiseMMF, self).__init__()
self.fusion = fusion
self.size_dict_path = {"small": [path_input_dim, 512, 256], "big": [1024, 512, 384]}
self.size_dict_omic = {'small': [256, 256]}
self.n_classes = n_classes
### Deep Sets Architecture Construction
size = self.size_dict_path[size_arg]
if dropinput:
fc = [nn.Dropout(dropinput), nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
else:
fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
fc.append(attention_net)
self.attention_net = nn.Sequential(*fc)
self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
### Constructing Genomic SNN
if self.fusion is not None:
if use_mlp:
Block = MLP_Block
else:
Block = SNN_Block
hidden = self.size_dict_omic['small']
fc_omic = [Block(dim1=omic_input_dim, dim2=hidden[0])]
for i, _ in enumerate(hidden[1:]):
fc_omic.append(Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
self.fc_omic = nn.Sequential(*fc_omic)
if self.fusion == 'concat':
self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
elif self.fusion == 'bilinear':
self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=scale_dim1, gate1=gate_path, scale_dim2=scale_dim2, gate2=gate_omic, skip=skip, mmhid=256)
elif self.fusion == 'lrb':
self.mm = LRBilinearFusion(dim1=256, dim2=256, scale_dim1=scale_dim1, gate1=gate_path, scale_dim2=scale_dim2, gate2=gate_omic)
else:
self.mm = None
self.classifier_mm = nn.Linear(size[2], n_classes)
def relocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
        else:
            self.attention_net = self.attention_net.to(device)
if self.fusion is not None:
self.fc_omic = self.fc_omic.to(device)
self.mm = self.mm.to(device)
self.rho = self.rho.to(device)
self.classifier_mm = self.classifier_mm.to(device)
def forward(self, **kwargs):
x_path = kwargs['x_path']
A, h_path = self.attention_net(x_path)
A = torch.transpose(A, 1, 0)
A_raw = A
A = F.softmax(A, dim=1)
h_path = torch.mm(A, h_path)
h_path = self.rho(h_path)
x_omic = kwargs['x_omic']
h_omic = self.fc_omic(x_omic)
if self.fusion == 'bilinear':
h_mm = self.mm(h_path, h_omic)
elif self.fusion == 'concat':
            h_mm = self.mm(torch.cat([h_path, h_omic], dim=1))
        elif self.fusion == 'lrb':
            # LRBilinearFusion already projects to n_classes, so its output is the logits
            h_mm = self.mm(h_path, h_omic)
            return h_mm
        h_mm = self.classifier_mm(h_mm)  # logits: [B x n_classes]
assert len(h_mm.shape) == 2 and h_mm.shape[1] == self.n_classes
return h_mm
def captum(self, h, X):
A, h = self.attention_net(h)
A = A.squeeze(dim=2)
A = F.softmax(A, dim=1)
M = torch.bmm(A.unsqueeze(dim=1), h).squeeze(dim=1) #M = torch.mm(A, h)
M = self.rho(M)
O = self.fc_omic(X)
        if self.fusion == 'bilinear':
            MM = self.mm(M, O)
        elif self.fusion == 'concat':
            MM = self.mm(torch.cat([M, O], dim=1))
        logits = self.classifier_mm(MM)
hazards = torch.sigmoid(logits)
S = torch.cumprod(1 - hazards, dim=1)
risk = -torch.sum(S, dim=1)
return risk
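

if __name__ == "__main__":
    # Minimal smoke test (sketch): the 80-d genomic input and the 100-patch bag
    # below are illustrative assumptions, not dimensions prescribed by the paper.
    model = PorpoiseMMF(omic_input_dim=80, fusion='bilinear', n_classes=4)
    x_path = torch.randn(100, 1024)  # bag of patch embeddings: N x path_input_dim
    x_omic = torch.randn(1, 80)      # one genomic profile: 1 x omic_input_dim
    logits = model(x_path=x_path, x_omic=x_omic)
    print(logits.shape)              # expected: torch.Size([1, 4])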