import collections
import math
import pickle
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Batch
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SubsetSequentialSampler(Sampler):
"""Samples elements sequentially from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
"""
def __init__(self, indices):
self.indices = indices
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
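# Example (illustrative): passing SubsetSequentialSampler([3, 7, 42]) as the `sampler` argument
# of a DataLoader visits exactly those dataset indices, in that order, on every pass; the
# debugging/testing path of get_split_loader below relies on this behavior.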
def collate_MIL(batch):
img = torch.cat([item[0] for item in batch], dim = 0)
label = torch.LongTensor([item[1] for item in batch])
return [img, label]
def collate_features(batch):
img = torch.cat([item[0] for item in batch], dim = 0)
coords = np.vstack([item[1] for item in batch])
return [img, coords]
def collate_MIL_survival(batch):
img = torch.cat([item[0] for item in batch], dim = 0)
omic = torch.cat([item[1] for item in batch], dim = 0).type(torch.FloatTensor)
label = torch.LongTensor([item[2] for item in batch])
event_time = torch.FloatTensor([item[3] for item in batch])
c = torch.FloatTensor([item[4] for item in batch])
return [img, omic, label, event_time, c]
def collate_MIL_survival_cluster(batch):
img = torch.cat([item[0] for item in batch], dim = 0)
cluster_ids = torch.cat([item[1] for item in batch], dim = 0).type(torch.LongTensor)
omic = torch.cat([item[2] for item in batch], dim = 0).type(torch.FloatTensor)
label = torch.LongTensor([item[3] for item in batch])
event_time = np.array([item[4] for item in batch])
c = torch.FloatTensor([item[5] for item in batch])
return [img, cluster_ids, omic, label, event_time, c]
def collate_MIL_survival_sig(batch):
img = torch.cat([item[0] for item in batch], dim = 0)
omic1 = torch.cat([item[1] for item in batch], dim = 0).type(torch.FloatTensor)
omic2 = torch.cat([item[2] for item in batch], dim = 0).type(torch.FloatTensor)
omic3 = torch.cat([item[3] for item in batch], dim = 0).type(torch.FloatTensor)
omic4 = torch.cat([item[4] for item in batch], dim = 0).type(torch.FloatTensor)
omic5 = torch.cat([item[5] for item in batch], dim = 0).type(torch.FloatTensor)
omic6 = torch.cat([item[6] for item in batch], dim = 0).type(torch.FloatTensor)
label = torch.LongTensor([item[7] for item in batch])
event_time = np.array([item[8] for item in batch])
c = torch.FloatTensor([item[9] for item in batch])
return [img, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c]
def get_simple_loader(dataset, batch_size=1):
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    loader = DataLoader(dataset, batch_size=batch_size, sampler=SequentialSampler(dataset), collate_fn=collate_MIL, **kwargs)
    return loader
def get_split_loader(split_dataset, training=False, testing=False, weighted=False, mode='coattn', batch_size=1):
    """
    Return the training, validation, or (subsampled) testing loader for a split dataset.
    """
    if mode == 'coattn':
        collate = collate_MIL_survival_sig
    elif mode == 'cluster':
        collate = collate_MIL_survival_cluster
    else:
        collate = collate_MIL_survival

    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=batch_size, sampler=WeightedRandomSampler(weights, len(weights)), collate_fn=collate, **kwargs)
            else:
                loader = DataLoader(split_dataset, batch_size=batch_size, sampler=RandomSampler(split_dataset), collate_fn=collate, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler=SequentialSampler(split_dataset), collate_fn=collate, **kwargs)
    else:
        # debugging path: iterate sequentially over a random 10% subset of the split
        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset) * 0.1), replace=False)
        loader = DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate, **kwargs)

    return loader
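# Example usage (illustrative sketch; `train_split` and `val_split` are hypothetical split
# datasets exposing getlabel() and slide_cls_ids, as assumed by the weighted path above):
#   train_loader = get_split_loader(train_split, training=True, weighted=True, mode='coattn', batch_size=1)
#   val_loader = get_split_loader(val_split, mode='coattn')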
def get_optim(model, args):
if args.opt == "adam":
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
elif args.opt == 'sgd':
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
else:
raise NotImplementedError
return optimizer
def print_network(net):
num_params = 0
num_params_train = 0
print(net)
for param in net.parameters():
n = param.numel()
num_params += n
if param.requires_grad:
num_params_train += n
print('Total number of parameters: %d' % num_params)
print('Total number of trainable parameters: %d' % num_params_train)
def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
seed = 7, label_frac = 1.0, custom_test_ids = None):
indices = np.arange(samples).astype(int)
if custom_test_ids is not None:
indices = np.setdiff1d(indices, custom_test_ids)
np.random.seed(seed)
for i in range(n_splits):
all_val_ids = []
all_test_ids = []
sampled_train_ids = []
if custom_test_ids is not None: # pre-built test split, do not need to sample
all_test_ids.extend(custom_test_ids)
for c in range(len(val_num)):
possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class
remaining_ids = possible_indices
if val_num[c] > 0:
val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids
remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation
all_val_ids.extend(val_ids)
if custom_test_ids is None and test_num[c] > 0: # sample test split
test_ids = np.random.choice(remaining_ids, test_num[c], replace = False)
remaining_ids = np.setdiff1d(remaining_ids, test_ids)
all_test_ids.extend(test_ids)
if label_frac == 1:
sampled_train_ids.extend(remaining_ids)
else:
sample_num = math.ceil(len(remaining_ids) * label_frac)
slice_ids = np.arange(sample_num)
sampled_train_ids.extend(remaining_ids[slice_ids])
yield sorted(sampled_train_ids), sorted(all_val_ids), sorted(all_test_ids)
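# Example usage (illustrative sketch; `cls_ids`, `val_num`, and `test_num` are hypothetical
# per-class index arrays and per-class counts): the generator is typically consumed with
# nth() below, e.g.
#   split_gen = generate_split(cls_ids, val_num, test_num, samples=num_slides, n_splits=5, seed=7)
#   train_ids, val_ids, test_ids = nth(split_gen, 0)   # splits for fold 0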
def nth(iterator, n, default=None):
if n is None:
return collections.deque(iterator, maxlen=0)
else:
return next(islice(iterator,n, None), default)
def calculate_error(Y_hat, Y):
error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
return error
def make_weights_for_balanced_classes_split(dataset):
N = float(len(dataset))
weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
weight = [0] * int(N)
for idx in range(len(dataset)):
y = dataset.getlabel(idx)
weight[idx] = weight_per_class[y]
return torch.DoubleTensor(weight)
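# Worked example (illustrative): with 10 slides split 8/2 across two classes,
# weight_per_class = [10/8, 10/2] = [1.25, 5.0], so each minority-class slide is drawn four
# times as often by WeightedRandomSampler and both classes are balanced in expectation.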
def initialize_weights(module):
for m in module.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def dfs_freeze(model):
for name, child in model.named_children():
for param in child.parameters():
param.requires_grad = False
dfs_freeze(child)
def dfs_unfreeze(model):
for name, child in model.named_children():
for param in child.parameters():
param.requires_grad = True
dfs_unfreeze(child)
# Legacy 1-indexed formulation (matches the commented-out neg_likelihood_loss below).
# Divide the continuous time scale into k discrete bins: T_cont \in {[0, a_1), [a_1, a_2), ..., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = 0 if T_cont \in (-inf, 0), Y = 1 if T_cont \in [0, a_1), Y = 2 if T_cont \in [a_1, a_2), ..., Y = k if T_cont \in [a_(k-1), inf)
# discrete hazard: h(t) = P(Y = t | Y >= t, X), t = 0,1,2,...,k
# survival function: S(t) = P(Y > t | X)
# all patients are alive on (-inf, 0) by definition, so P(Y = 0) = 0
# h(0) = 0 ---> does not need to be modeled
# S(0) = P(Y > 0 | X) = 1 ---> does not need to be modeled
'''
Summary: the neural network models the hazard function h(t) for t = 1,2,...,k,
corresponding to Y = 1,...,k. h(t) is the probability that the patient dies in
[0, a_1), [a_1, a_2), ..., [a_(k-1), inf), respectively, given survival up to that bin.
'''
# def neg_likelihood_loss(hazards, Y, c):
# batch_size = len(Y)
# Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
# c = c.view(batch_size, 1).float() #censorship status, 0 or 1
# S = torch.cumprod(1 - hazards, dim=1) # surival is cumulative product of 1 - hazards
# # without padding, S(1) = S[0], h(1) = h[0]
# S_padded = torch.cat([torch.ones_like(c), S], 1) #S(0) = 1, all patients are alive from (-inf, 0) by definition
# # after padding, S(0) = S[0], S(1) = S[1], etc, h(1) = h[0]
# #h[y] = h(1)
# #S[1] = S(1)
# neg_l = - c * torch.log(torch.gather(S_padded, 1, Y)) - (1 - c) * (torch.log(torch.gather(S_padded, 1, Y-1)) + torch.log(hazards[:, Y-1]))
# neg_l = neg_l.mean()
# return neg_l
# 0-indexed formulation used by nll_loss and ce_loss below.
# Divide the continuous time scale into k discrete bins: T_cont \in {[0, a_1), [a_1, a_2), ..., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1), Y = 1 if T_cont \in [a_1, a_2), ..., Y = k-1 if T_cont \in [a_(k-1), inf)
# discrete hazard: h(t) = P(Y = t | Y >= t, X), t = -1,0,1,...,k-1
# survival function: S(t) = P(Y > t | X)
# all patients are alive on (-inf, 0) by definition, so P(Y = -1) = 0
# h(-1) = 0 ---> does not need to be modeled
# S(-1) = P(Y > -1 | X) = 1 ---> does not need to be modeled
'''
Summary: the neural network models the hazard function h(t) for t = 0,1,...,k-1,
corresponding to Y = 0,1,...,k-1. h(t) is the probability that the patient dies in
[0, a_1), [a_1, a_2), ..., [a_(k-1), inf), respectively, given survival up to that bin.
'''
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1)  # ground-truth bin, 0,1,...,k-1
    c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)  # survival is the cumulative product of (1 - hazards)
    # without padding, S[:, t] = S(t) and hazards[:, t] = h(t)
    S_padded = torch.cat([torch.ones_like(c), S], 1)  # S(-1) = 1: all patients are alive on (-inf, 0) by definition
    # after padding, S_padded[:, t] = S(t-1), so S_padded[:, Y] = S(Y-1) and S_padded[:, Y+1] = S(Y)
    uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    loss = (1 - alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss
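# Illustrative sanity check for nll_loss (hypothetical values, not part of the training
# pipeline): with k = 4 bins and a single uncensored patient whose event falls in bin Y = 2,
# S = cumprod(1 - hazards) = [0.9, 0.72, 0.504, 0.3024], and with alpha = 0 the loss reduces
# to -(log S(1) + log h(2)) = -(log 0.72 + log 0.3) ≈ 1.5325:
#   hazards = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
#   Y = torch.tensor([[2]])
#   c = torch.tensor([[0.]])
#   nll_loss(hazards, None, Y, c, alpha=0.0)   # -> tensor(1.5325)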
def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1)  # ground-truth bin, 0,1,...,k-1
    c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)  # survival is the cumulative product of (1 - hazards)
    # after padding, S_padded[:, t] = S(t-1), so S_padded[:, Y] = S(Y-1); S_padded[:, 0] = S(-1) = 1
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    # reg: negative log-likelihood term on uncensored patients (same form as in nll_loss)
    reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    # ce_l: binary cross-entropy on S(Y): censored patients should have high S(Y), uncensored patients low S(Y)
    ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
    loss = (1 - alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss
# def nll_loss(hazards, Y, c, S=None, alpha=0.4, eps=1e-8):
# batch_size = len(Y)
# Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
# c = c.view(batch_size, 1).float() #censorship status, 0 or 1
# if S is None:
# S = 1 - torch.cumsum(hazards, dim=1) # surival is cumulative product of 1 - hazards
# uncensored_loss = -(1 - c) * (torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
# censored_loss = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps))
# loss = censored_loss + uncensored_loss
# loss = loss.mean()
# return loss
class CrossEntropySurvLoss(object):
def __init__(self, alpha=0.15):
self.alpha = alpha
def __call__(self, hazards, S, Y, c, alpha=None):
if alpha is None:
return ce_loss(hazards, S, Y, c, alpha=self.alpha)
else:
return ce_loss(hazards, S, Y, c, alpha=alpha)
# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0)
class NLLSurvLoss_dep(object):
def __init__(self, alpha=0.15):
self.alpha = alpha
def __call__(self, hazards, S, Y, c, alpha=None):
if alpha is None:
return nll_loss(hazards, S, Y, c, alpha=self.alpha)
else:
return nll_loss(hazards, S, Y, c, alpha=alpha)
# h_padded = torch.cat([torch.zeros_like(c), hazards], 1)
#reg = - (1 - c) * (torch.log(torch.gather(hazards, 1, Y)) + torch.gather(torch.cumsum(torch.log(1-h_padded), dim=1), 1, Y))
class CoxSurvLoss(object):
    def __call__(self, hazards, S, c, **kwargs):
        # This calculation is credited to Travers Ching, https://github.com/traversc/cox-nnet
        # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
        current_batch_len = len(S)
        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i, j] = S[j] >= S[i]
        R_mat = torch.FloatTensor(R_mat).to(device)
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * (1 - c))
        return loss_cox
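# Written out, CoxSurvLoss computes the negative Cox partial log-likelihood over the batch
# (note that the argument named S is compared elementwise to build risk sets, i.e. it is
# treated as the observed survival/event time here, not as a survival function):
#   loss = -(1/N) * sum_i (1 - c_i) * [ theta_i - log( sum_{j : S_j >= S_i} exp(theta_j) ) ]
# where theta_i is the predicted risk score for patient i and the inner sum runs over the
# risk set encoded in R_mat.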
def l1_reg_all(model, reg_type=None):
l1_reg = None
for W in model.parameters():
if l1_reg is None:
l1_reg = torch.abs(W).sum()
else:
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1)
return l1_reg
def l1_reg_modules(model, reg_type=None):
l1_reg = 0
l1_reg += l1_reg_all(model.fc_omic)
l1_reg += l1_reg_all(model.mm)
return l1_reg
def l1_reg_omic(model, reg_type=None):
l1_reg = 0
if hasattr(model, 'fc_omic'):
l1_reg += l1_reg_all(model.fc_omic)
else:
l1_reg += l1_reg_all(model)
return l1_reg
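# Usage note (illustrative sketch): the scalar returned by these helpers is typically scaled
# and added to the training objective, e.g. loss = loss_fn(...) + args.lambda_reg * l1_reg_all(model),
# matching the lambda_reg argument referenced in get_custom_exp_code below.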
def get_custom_exp_code(args):
r"""
    Updates the argparse.Namespace with a custom experiment code.
    Args:
        - args (Namespace)
    Returns:
        - args (Namespace)
"""
exp_code = '_'.join(args.split_dir.split('_')[:2])
dataset_path = 'datasets_csv'
param_code = ''
### Model Type
if args.model_type == 'porpoise_mmf':
param_code += 'PorpoiseMMF'
elif args.model_type == 'porpoise_amil':
param_code += 'PorpoiseAMIL'
elif args.model_type == 'max_net' or args.model_type == 'snn':
param_code += 'SNN'
elif args.model_type == 'amil':
param_code += 'AMIL'
elif args.model_type == 'deepset':
param_code += 'DS'
elif args.model_type == 'mi_fcn':
param_code += 'MIFCN'
elif args.model_type == 'mcat':
param_code += 'MCAT'
else:
raise NotImplementedError
### Loss Function
param_code += '_%s' % args.bag_loss
if args.bag_loss in ['nll_surv']:
param_code += '_a%s' % str(args.alpha_surv)
### Learning Rate
if args.lr != 2e-4:
param_code += '_lr%s' % format(args.lr, '.0e')
### L1-Regularization
if args.reg_type != 'None':
param_code += '_%sreg%s' % (args.reg_type, format(args.lambda_reg, '.0e'))
if args.dropinput:
param_code += '_drop%s' % str(int(args.dropinput*100))
param_code += '_%s' % args.which_splits.split("_")[0]
### Batch Size
if args.batch_size != 1:
param_code += '_b%s' % str(args.batch_size)
### Gradient Accumulation
if args.gc != 1:
param_code += '_gc%s' % str(args.gc)
### Applying Which Features
if args.apply_sigfeats:
param_code += '_sig'
dataset_path += '_sig'
elif args.apply_mutsig:
param_code += '_mutsig'
dataset_path += '_mutsig'
### Fusion Operation
if args.fusion != "None":
param_code += '_' + args.fusion
### Updating
args.exp_code = exp_code + "_" + param_code
args.param_code = param_code
args.dataset_path = dataset_path
return args
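# Example (illustrative; the argument values are hypothetical): with split_dir='tcga_brca',
# model_type='porpoise_mmf', bag_loss='nll_surv', alpha_surv=0.5, lr=2e-4, reg_type='None',
# dropinput=0, which_splits='5foldcv', batch_size=1, gc=1, no sig/mutsig features, and
# fusion='None', this yields args.exp_code == 'tcga_brca_PorpoiseMMF_nll_surv_a0.5_5foldcv'.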