--- a
+++ b/models/model_set_mil.py
@@ -0,0 +1,297 @@
+from collections import OrderedDict
+from os.path import join
+import pdb
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.model_utils import *
+
+
+
+################################
+### Deep Sets Implementation ###
+################################
+class MIL_Sum_FC_surv(nn.Module):
+    def __init__(self, omic_input_dim=None, fusion=None, size_arg="small", dropout=0.25, n_classes=4):
+        r"""
+        Deep Sets Implementation.
+
+        Args:
+            omic_input_dim (int): Dimension size of genomic features.
+            fusion (str): Fusion method (Choices: concat, bilinear, or None)
+            size_arg (str): Size of NN architecture (Choices: small or big)
+            dropout (float): Dropout rate
+            n_classes (int): Number of output classes
+        """
+        super(MIL_Sum_FC_surv, self).__init__()
+        self.fusion = fusion
+        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+
+        ### Deep Sets Architecture Construction
+        size = self.size_dict_path[size_arg]
+        self.phi = nn.Sequential(*[nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)])
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Constructing Genomic SNN
+        if self.fusion is not None:
+            hidden = [256, 256]
+            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+
+            if self.fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
+            else:
+                self.mm = None
+
+        self.classifier = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.phi = nn.DataParallel(self.phi, device_ids=device_ids).to('cuda:0')
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.rho = self.rho.to(device)
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        x_path = kwargs['x_path']
+
+        h_path = self.phi(x_path).sum(dim=0)  # sum-pool instance embeddings: [N x 512] -> [512]
+        h_path = self.rho(h_path)             # [256]
+
+        if self.fusion is not None:
+            x_omic = kwargs['x_omic']
+            h_omic = self.fc_omic(x_omic).squeeze(dim=0)
+            if self.fusion == 'bilinear':
+                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
+            elif self.fusion == 'concat':
+                h = self.mm(torch.cat([h_path, h_omic], dim=0))
+        else:
+            h = h_path  # [256] vector
+
+        logits = self.classifier(h).unsqueeze(0)  # [1 x n_classes]
+        Y_hat = torch.topk(logits, 1, dim=1)[1]
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+
+        return hazards, S, Y_hat, None, None
+
+
+
+################################
+# Attention MIL Implementation #
+################################
+class MIL_Attention_FC_surv(nn.Module):
+    def __init__(self, omic_input_dim=None, fusion=None, size_arg="small", dropout=0.25, n_classes=4):
+        r"""
+        Attention MIL Implementation.
+
+        Args:
+            omic_input_dim (int): Dimension size of genomic features.
+            fusion (str): Fusion method (Choices: concat, bilinear, or None)
+            size_arg (str): Size of NN architecture (Choices: small or big)
+            dropout (float): Dropout rate
+            n_classes (int): Number of output classes
+        """
+        super(MIL_Attention_FC_surv, self).__init__()
+        self.fusion = fusion
+        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+
+        ### Attention MIL Architecture Construction
+        size = self.size_dict_path[size_arg]
+        fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
+        fc.append(attention_net)
+        self.attention_net = nn.Sequential(*fc)
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Constructing Genomic SNN
+        if self.fusion is not None:
+            hidden = [256, 256]
+            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+
+            if self.fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
+            else:
+                self.mm = None
+
+        self.classifier = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.rho = self.rho.to(device)
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        x_path = kwargs['x_path']
+
+        A, h_path = self.attention_net(x_path)  # A: [N x 1] attention scores, h_path: [N x 512]
+        A = torch.transpose(A, 1, 0)            # [1 x N]
+        A_raw = A
+        A = F.softmax(A, dim=1)                 # normalize attention over instances
+        h_path = torch.mm(A, h_path)            # attention-weighted bag embedding: [1 x 512]
+        h_path = self.rho(h_path).squeeze()     # [256]
+
+        if self.fusion is not None:
+            x_omic = kwargs['x_omic']
+            h_omic = self.fc_omic(x_omic)
+            if self.fusion == 'bilinear':
+                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
+            elif self.fusion == 'concat':
+                h = self.mm(torch.cat([h_path, h_omic], dim=0))
+        else:
+            h = h_path  # [256] vector
+
+        logits = self.classifier(h).unsqueeze(0)  # [1 x n_classes]
+        Y_hat = torch.topk(logits, 1, dim=1)[1]
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+
+        return hazards, S, Y_hat, None, None
+
+
+
+######################################
+# Deep Attention MISL Implementation #
+######################################
+class MIL_Cluster_FC_surv(nn.Module):
+    def __init__(self, omic_input_dim=None, fusion=None, num_clusters=10, size_arg="small", dropout=0.25, n_classes=4):
+        r"""
+        Deep Attention MISL (cluster-based Attention MIL) Implementation.
+
+        Args:
+            omic_input_dim (int): Dimension size of genomic features.
+            fusion (str): Fusion method (Choices: concat, bilinear, or None)
+            num_clusters (int): Number of instance-level phenotype clusters
+            size_arg (str): Size of NN architecture (Choices: small or big)
+            dropout (float): Dropout rate
+            n_classes (int): Number of output classes
+        """
+        super(MIL_Cluster_FC_surv, self).__init__()
+        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
+        self.size_dict_omic = {'small': [256, 256]}
+        self.num_clusters = num_clusters
+        self.fusion = fusion
+
+        ### FC Cluster layers + Pooling
+        size = self.size_dict_path[size_arg]
+        phis = []
+        for phenotype_i in range(num_clusters):
+            phi = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout),
+                   nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+            phis.append(nn.Sequential(*phi))
+        self.phis = nn.ModuleList(phis)
+        self.pool1d = nn.AdaptiveAvgPool1d(1)
+
+        ### WSI Attention MIL Construction
+        fc = [nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
+        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
+        fc.append(attention_net)
+        self.attention_net = nn.Sequential(*fc)
+        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])
+
+        ### Genomic SNN Construction + Multimodal Fusion
+        if self.fusion is not None:
+            hidden = self.size_dict_omic['small']
+            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
+            for i, _ in enumerate(hidden[1:]):
+                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
+            self.fc_omic = nn.Sequential(*fc_omic)
+
+            if self.fusion == 'concat':
+                self.mm = nn.Sequential(*[nn.Linear(size[2]*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
+            elif self.fusion == 'bilinear':
+                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
+            else:
+                self.mm = None
+
+        self.classifier = nn.Linear(size[2], n_classes)
+
+
+    def relocate(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.device_count() >= 1:
+            device_ids = list(range(torch.cuda.device_count()))
+            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
+        else:
+            self.attention_net = self.attention_net.to(device)
+
+        if self.fusion is not None:
+            self.fc_omic = self.fc_omic.to(device)
+            self.mm = self.mm.to(device)
+
+        self.phis = self.phis.to(device)
+        self.pool1d = self.pool1d.to(device)
+        self.rho = self.rho.to(device)
+        self.classifier = self.classifier.to(device)
+
+
+    def forward(self, **kwargs):
+        x_path = kwargs['x_path']
+        cluster_id = kwargs['cluster_id'].detach().cpu().numpy()
+
+        ### FC Cluster layers + Pooling
+        h_cluster = []
+        for i in range(self.num_clusters):
+            h_cluster_i = self.phis[i](x_path[cluster_id == i])
+            if h_cluster_i.shape[0] == 0:
+                # empty cluster: substitute a zero embedding on the same device as the input
+                h_cluster_i = torch.zeros((1, h_cluster_i.shape[1]), device=x_path.device)
+            h_cluster.append(self.pool1d(h_cluster_i.T.unsqueeze(0)).squeeze(2))
+        h_cluster = torch.stack(h_cluster, dim=1).squeeze(0)  # [num_clusters x 512]
+
+        ### Attention MIL
+        A, h_path = self.attention_net(h_cluster)  # A: [num_clusters x 1], h_path: [num_clusters x 512]
+        A = torch.transpose(A, 1, 0)
+        A_raw = A
+        A = F.softmax(A, dim=1)
+        h_path = torch.mm(A, h_path)
+        h_path = self.rho(h_path).squeeze()  # [256]
+
+        ### Attention MIL + Genomic Fusion
+        if self.fusion is not None:
+            x_omic = kwargs['x_omic']
+            h_omic = self.fc_omic(x_omic)
+            if self.fusion == 'bilinear':
+                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
+            elif self.fusion == 'concat':
+                h = self.mm(torch.cat([h_path, h_omic], dim=0))
+        else:
+            h = h_path
+
+        logits = self.classifier(h).unsqueeze(0)  # [1 x n_classes]
+        Y_hat = torch.topk(logits, 1, dim=1)[1]
+        hazards = torch.sigmoid(logits)
+        S = torch.cumprod(1 - hazards, dim=1)
+
+        return hazards, S, Y_hat, None, None
\ No newline at end of file
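
For reference, below is a minimal usage sketch for the three survival heads added in this patch (it is not part of the patch itself). It assumes that `SNN_Block`, `Attn_Net_Gated`, and `BilinearFusion` come from `models.model_utils` as the wildcard import implies, that `x_path` is a bag of N patch embeddings of dimension 1024 (matching `size[0]`), that `x_omic` is a 1-D genomic feature vector, and that `cluster_id` holds one cluster index per patch; the bag size of 500, the 80 genomic features, and the 10 clusters are illustrative values, not values from the patch.

```python
import torch

from models.model_set_mil import (
    MIL_Sum_FC_surv,
    MIL_Attention_FC_surv,
    MIL_Cluster_FC_surv,
)

# Toy inputs: 500 patch embeddings of dim 1024 (size[0] in the size dicts),
# an 80-dim genomic vector, and one cluster label per patch (10 phenotypes).
x_path = torch.randn(500, 1024)
x_omic = torch.randn(80)
cluster_id = torch.randint(0, 10, (500,))

# Unimodal Deep Sets baseline: with fusion=None the genomic branch is skipped.
set_model = MIL_Sum_FC_surv(fusion=None, n_classes=4)
hazards, S, Y_hat, _, _ = set_model(x_path=x_path)
print(hazards.shape, S.shape, Y_hat.shape)  # [1, 4], [1, 4], [1, 1]

# Multimodal Attention MIL head with concatenation fusion.
attn_model = MIL_Attention_FC_surv(omic_input_dim=80, fusion='concat', n_classes=4)
hazards, S, Y_hat, _, _ = attn_model(x_path=x_path, x_omic=x_omic)

# Cluster-based Deep Attention MISL head: additionally expects cluster_id.
misl_model = MIL_Cluster_FC_surv(omic_input_dim=80, fusion='concat',
                                 num_clusters=10, n_classes=4)
hazards, S, Y_hat, _, _ = misl_model(x_path=x_path, x_omic=x_omic,
                                     cluster_id=cluster_id)

# Discrete-time survival readout: hazards are per-bin probabilities and
# S = cumprod(1 - hazards) is the corresponding survival curve.
print(S)
```

All three heads share the same `(hazards, S, Y_hat, None, None)` return signature, so they can be swapped behind a common training loop.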