Diff of /utils/utils.py [000000] .. [405115]

--- a
+++ b/utils/utils.py
@@ -0,0 +1,456 @@
+import pickle
+import math
+import collections
+from itertools import islice
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler
+from torch.utils.data.dataloader import default_collate
+from torchvision import transforms
+
+import torch_geometric
+from torch_geometric.data import Batch
+
+device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class SubsetSequentialSampler(Sampler):
+    """Samples elements sequentially from a given list of indices, without replacement.
+
+    Arguments:
+        indices (sequence): a sequence of indices
+    """
+    def __init__(self, indices):
+        self.indices = indices
+
+    def __iter__(self):
+        return iter(self.indices)
+
+    def __len__(self):
+        return len(self.indices)
+
+def collate_MIL(batch):
+    img = torch.cat([item[0] for item in batch], dim = 0)
+    label = torch.LongTensor([item[1] for item in batch])
+    return [img, label]
+
+def collate_features(batch):
+    img = torch.cat([item[0] for item in batch], dim = 0)
+    coords = np.vstack([item[1] for item in batch])
+    return [img, coords]
+
+def collate_MIL_survival(batch):
+    img = torch.cat([item[0] for item in batch], dim = 0)
+    omic = torch.cat([item[1] for item in batch], dim = 0).type(torch.FloatTensor)
+    label = torch.LongTensor([item[2] for item in batch])
+    event_time = torch.FloatTensor([item[3] for item in batch])
+    c = torch.FloatTensor([item[4] for item in batch])
+    return [img, omic, label, event_time, c]
+
+def collate_MIL_survival_cluster(batch):
+    img = torch.cat([item[0] for item in batch], dim = 0)
+    cluster_ids = torch.cat([item[1] for item in batch], dim = 0).type(torch.LongTensor)
+    omic = torch.cat([item[2] for item in batch], dim = 0).type(torch.FloatTensor)
+    label = torch.LongTensor([item[3] for item in batch])
+    event_time = np.array([item[4] for item in batch])
+    c = torch.FloatTensor([item[5] for item in batch])
+    return [img, cluster_ids, omic, label, event_time, c]
+
+def collate_MIL_survival_sig(batch):
+    img = torch.cat([item[0] for item in batch], dim = 0)
+    omic1 = torch.cat([item[1] for item in batch], dim = 0).type(torch.FloatTensor)
+    omic2 = torch.cat([item[2] for item in batch], dim = 0).type(torch.FloatTensor)
+    omic3 = torch.cat([item[3] for item in batch], dim = 0).type(torch.FloatTensor)
+    omic4 = torch.cat([item[4] for item in batch], dim = 0).type(torch.FloatTensor)
+    omic5 = torch.cat([item[5] for item in batch], dim = 0).type(torch.FloatTensor)
+    omic6 = torch.cat([item[6] for item in batch], dim = 0).type(torch.FloatTensor)
+
+    label = torch.LongTensor([item[7] for item in batch])
+    event_time = np.array([item[8] for item in batch])
+    c = torch.FloatTensor([item[9] for item in batch])
+    return [img, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c]
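+# Note on the collate functions above: each dataset item is assumed to be a tuple laid out in the
+# order gathered here (an assumption inferred from the indexing, not enforced at runtime), e.g. for
+# collate_MIL_survival_sig: (patch_features, omic1, ..., omic6, label, event_time, censorship),
+# where patch_features is an (N_patches x d) tensor and each omic_i is a per-signature feature
+# tensor that can be concatenated along dim 0 across the batch.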
+
+def get_simple_loader(dataset, batch_size=1):
+    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
+    loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs)
+    return loader 
+
+def get_split_loader(split_dataset, training = False, testing = False, weighted = False, mode='coattn', batch_size=1):
+    """
+        return either the validation loader or training loader 
+    """
+    if mode == 'coattn':
+        collate = collate_MIL_survival_sig
+    elif mode == 'cluster':
+        collate = collate_MIL_survival_cluster
+    else:
+        collate = collate_MIL_survival
+
+    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
+    if not testing:
+        if training:
+            if weighted:
+                weights = make_weights_for_balanced_classes_split(split_dataset)
+                loader = DataLoader(split_dataset, batch_size=batch_size, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate, **kwargs)    
+            else:
+                loader = DataLoader(split_dataset, batch_size=batch_size, sampler = RandomSampler(split_dataset), collate_fn = collate, **kwargs)
+        else:
+            loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate, **kwargs)
+    
+    else:
+        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset)*0.1), replace = False)
+        loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate, **kwargs )
+
+    return loader
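+# Minimal usage sketch (commented out; assumes `train_split` / `val_split` are survival datasets
+# compatible with the collate functions above and, for the weighted path, that they expose
+# `slide_cls_ids` and `getlabel` as required by make_weights_for_balanced_classes_split):
+#
+#   train_loader = get_split_loader(train_split, training=True, weighted=True, mode='coattn', batch_size=1)
+#   val_loader   = get_split_loader(val_split, training=False, mode='coattn')
+#   for batch in train_loader:
+#       img, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c = batch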
+
+def get_optim(model, args):
+    if args.opt == "adam":
+        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
+    elif args.opt == 'sgd':
+        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
+    else:
+        raise NotImplementedError
+    return optimizer
+
+def print_network(net):
+    num_params = 0
+    num_params_train = 0
+    print(net)
+    
+    for param in net.parameters():
+        n = param.numel()
+        num_params += n
+        if param.requires_grad:
+            num_params_train += n
+    
+    print('Total number of parameters: %d' % num_params)
+    print('Total number of trainable parameters: %d' % num_params_train)
+
+
+def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
+    seed = 7, label_frac = 1.0, custom_test_ids = None):
+    indices = np.arange(samples).astype(int)
+    
+    if custom_test_ids is not None:
+        indices = np.setdiff1d(indices, custom_test_ids)
+
+    np.random.seed(seed)
+    for i in range(n_splits):
+        all_val_ids = []
+        all_test_ids = []
+        sampled_train_ids = []
+        
+        if custom_test_ids is not None: # pre-built test split, do not need to sample
+            all_test_ids.extend(custom_test_ids)
+
+        for c in range(len(val_num)):
+            possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class
+            remaining_ids = possible_indices
+
+            if val_num[c] > 0:
+                val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids
+                remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation
+                all_val_ids.extend(val_ids)
+
+            if custom_test_ids is None and test_num[c] > 0: # sample test split
+
+                test_ids = np.random.choice(remaining_ids, test_num[c], replace = False)
+                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
+                all_test_ids.extend(test_ids)
+
+            if label_frac == 1:
+                sampled_train_ids.extend(remaining_ids)
+            
+            else:
+                sample_num  = math.ceil(len(remaining_ids) * label_frac)
+                slice_ids = np.arange(sample_num)
+                sampled_train_ids.extend(remaining_ids[slice_ids])
+
+        yield sorted(sampled_train_ids), sorted(all_val_ids), sorted(all_test_ids)
+
+
+def nth(iterator, n, default=None):
+    if n is None:
+        return collections.deque(iterator, maxlen=0)
+    else:
+        return next(islice(iterator,n, None), default)
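+# Minimal sketch of how generate_split and nth are typically combined (hypothetical inputs;
+# cls_ids would normally come from a dataset's per-class index lists):
+#
+#   split_gen = generate_split(cls_ids, val_num=[5, 5], test_num=[5, 5],
+#                              samples=100, n_splits=5, seed=7)
+#   train_ids, val_ids, test_ids = nth(split_gen, 2)   # take the 3rd split (0-indexed)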
+
+def calculate_error(Y_hat, Y):
+    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
+
+    return error
+
+def make_weights_for_balanced_classes_split(dataset):
+    N = float(len(dataset))                                           
+    weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]                                                                                                     
+    weight = [0] * int(N)                                           
+    for idx in range(len(dataset)):   
+        y = dataset.getlabel(idx)                        
+        weight[idx] = weight_per_class[y]                                  
+
+    return torch.DoubleTensor(weight)
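+# Example of the weighting above: with N = 50 slides split into classes of size 10 and 40,
+# weight_per_class = [50/10, 50/40] = [5.0, 1.25], so each minority-class slide is drawn
+# four times as often by WeightedRandomSampler.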
+
+def initialize_weights(module):
+    for m in module.modules():
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_normal_(m.weight)
+            m.bias.data.zero_()
+        
+        elif isinstance(m, nn.BatchNorm1d):
+            nn.init.constant_(m.weight, 1)
+            nn.init.constant_(m.bias, 0)
+
+
+def dfs_freeze(model):
+    for name, child in model.named_children():
+        for param in child.parameters():
+            param.requires_grad = False
+        dfs_freeze(child)
+
+
+def dfs_unfreeze(model):
+    for name, child in model.named_children():
+        for param in child.parameters():
+            param.requires_grad = True
+        dfs_unfreeze(child)
+
+
+# divide continuous time scale into k discrete bins in total,  T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)}
+# Y = T_discrete is the discrete event time:
+# Y = 0 if T_cont \in (-inf, 0), Y = 1 if T_cont \in [0, a_1),  Y = 2 if T_cont in [a_1, a_2), ..., Y = k if T_cont in [a_(k-1), inf)
+# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X),  t = 0,1,2,...,k
+# S: survival function: P(Y > t | X)
+# all patients are alive from (-inf, 0) by definition, so P(Y=0) = 0
+# h(0) = 0 ---> do not need to model
+# S(0) = P(Y > 0 | X) = 1 ----> do not need to model
+'''
+Summary: the neural network models the hazard probability function h(t) for t = 1, 2, ..., k,
+corresponding to Y = 1, ..., k. h(t) is the probability that the patient dies in bin t,
+i.e. in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf) respectively.
+'''
+# def neg_likelihood_loss(hazards, Y, c):
+#   batch_size = len(Y)
+#   Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
+#   c = c.view(batch_size, 1).float() #censorship status, 0 or 1
+#   S = torch.cumprod(1 - hazards, dim=1) # survival is cumulative product of 1 - hazards
+#   # without padding, S(1) = S[0], h(1) = h[0]
+#   S_padded = torch.cat([torch.ones_like(c), S], 1) #S(0) = 1, all patients are alive from (-inf, 0) by definition
+#   # after padding, S(0) = S[0], S(1) = S[1], etc, h(1) = h[0]
+#   #h[y] = h(1)
+#   #S[1] = S(1)
+#   neg_l = - c * torch.log(torch.gather(S_padded, 1, Y)) - (1 - c) * (torch.log(torch.gather(S_padded, 1, Y-1)) + torch.log(hazards[:, Y-1]))
+#   neg_l = neg_l.mean()
+#   return neg_l
+
+
+# divide continuous time scale into k discrete bins in total,  T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)}
+# Y = T_discrete is the discrete event time:
+# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1),  Y = 1 if T_cont in [a_1, a_2), ..., Y = k-1 if T_cont in [a_(k-1), inf)
+# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X),  t = -1,0,1,...,k-1
+# S: survival function: P(Y > t | X)
+# all patients are alive from (-inf, 0) by definition, so P(Y=-1) = 0
+# h(-1) = 0 ---> do not need to model
+# S(-1) = P(Y > -1 | X) = 1 ----> do not need to model
+'''
+Summary: the neural network models the hazard probability function h(t) for t = 0, 1, ..., k-1,
+corresponding to Y = 0, 1, ..., k-1. h(t) is the probability that the patient dies in bin t,
+i.e. in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf) respectively.
+'''
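+# Concrete example of the binning above: with k = 4 bins and cut points a_1 < a_2 < a_3,
+# a patient whose continuous event time falls in [a_1, a_2) gets discrete label Y = 1,
+# the network outputs hazards h(0), ..., h(3), and S(1) = (1 - h(0)) * (1 - h(1)).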
+def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
+    batch_size = len(Y)
+    Y = Y.view(batch_size, 1) # ground truth bin, 0,1,...,k-1
+    c = c.view(batch_size, 1).float() # censorship status, 0 or 1
+    if S is None:
+        S = torch.cumprod(1 - hazards, dim=1) # survival is the cumulative product of (1 - hazards)
+    # without padding, S(0) = S[0], h(0) = h[0]
+    S_padded = torch.cat([torch.ones_like(c), S], 1) # S(-1) = 1: all patients are alive on (-inf, 0) by definition
+    # after padding: S_padded[:, 0] = S(-1) = 1, S_padded[:, y] = S(y-1), S_padded[:, y+1] = S(y); hazards[:, y] = h(y)
+    uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
+    censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
+    neg_l = censored_loss + uncensored_loss
+    loss = (1-alpha) * neg_l + alpha * uncensored_loss
+    loss = loss.mean()
+    return loss
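+# Per-sample form of the loss above (writing h(y) = hazards[:, y] and S(y) for the cumulative product):
+#   uncensored (c = 0):  -[ log S(Y-1) + log h(Y) ]   (probability of dying exactly in bin Y)
+#   censored   (c = 1):  - log S(Y)                   (probability of surviving past bin Y)
+# and loss = (1 - alpha) * (censored + uncensored) + alpha * uncensored, i.e. alpha only
+# down-weights the censored term. A tiny sanity check (commented out, hypothetical values):
+#
+#   hazards = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
+#   Y = torch.tensor([1]); c = torch.tensor([0.])
+#   nll_loss(hazards, None, Y, c, alpha=0.0)
+#   # = -(log(0.9) + log(0.2)) ≈ 1.715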
+
+def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
+    batch_size = len(Y)
+    Y = Y.view(batch_size, 1) # ground truth bin, 0,1,...,k-1
+    c = c.view(batch_size, 1).float() # censorship status, 0 or 1
+    if S is None:
+        S = torch.cumprod(1 - hazards, dim=1) # survival is the cumulative product of (1 - hazards)
+    # after padding: S_padded[:, 0] = S(-1) = 1, S_padded[:, y] = S(y-1), S_padded[:, y+1] = S(y)
+    S_padded = torch.cat([torch.ones_like(c), S], 1)
+    reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
+    ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
+    loss = (1-alpha) * ce_l + alpha * reg
+    loss = loss.mean()
+    return loss
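+# Compared with nll_loss above, the `reg` term is essentially the uncensored negative
+# log-likelihood (up to the eps handling), while `ce_l` treats surviving past bin Y as a
+# binary outcome and applies a cross-entropy on S(Y) with the censorship status as the target.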
+
+# def nll_loss(hazards, Y, c, S=None, alpha=0.4, eps=1e-8):
+#   batch_size = len(Y)
+#   Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
+#   c = c.view(batch_size, 1).float() #censorship status, 0 or 1
+#   if S is None:
+#       S = 1 - torch.cumsum(hazards, dim=1) # survival approximated as 1 - cumulative sum of hazards
+#   uncensored_loss = -(1 - c) * (torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
+#   censored_loss = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps))
+#   loss = censored_loss + uncensored_loss
+#   loss = loss.mean()
+#   return loss
+
+class CrossEntropySurvLoss(object):
+    def __init__(self, alpha=0.15):
+        self.alpha = alpha
+
+    def __call__(self, hazards, S, Y, c, alpha=None): 
+        if alpha is None:
+            return ce_loss(hazards, S, Y, c, alpha=self.alpha)
+        else:
+            return ce_loss(hazards, S, Y, c, alpha=alpha)
+
+# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0)
+class NLLSurvLoss_dep(object):
+    def __init__(self, alpha=0.15):
+        self.alpha = alpha
+
+    def __call__(self, hazards, S, Y, c, alpha=None):
+        if alpha is None:
+            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
+        else:
+            return nll_loss(hazards, S, Y, c, alpha=alpha)
+    # h_padded = torch.cat([torch.zeros_like(c), hazards], 1)
+    #reg = - (1 - c) * (torch.log(torch.gather(hazards, 1, Y)) + torch.gather(torch.cumsum(torch.log(1-h_padded), dim=1), 1, Y))
+
+
+class CoxSurvLoss(object):
+    def __call__(self, hazards, S, c, **kwargs):
+        # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
+        # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
+        current_batch_len = len(S)
+        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
+        for i in range(current_batch_len):
+            for j in range(current_batch_len):
+                R_mat[i,j] = S[j] >= S[i]
+
+        R_mat = torch.FloatTensor(R_mat).to(device)
+        theta = hazards.reshape(-1)
+        exp_theta = torch.exp(theta)
+        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (1-c))
+        return loss_cox
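+# The double loop above builds the at-risk indicator R_mat[i, j] = 1{S[j] >= S[i]}, so the value
+# returned is the negative Cox partial log-likelihood
+#   - mean_i (1 - c_i) * [ theta_i - log( sum_{j : S_j >= S_i} exp(theta_j) ) ],
+# where censored samples (c_i = 1) are zeroed out by the (1 - c) factor but still counted in the mean.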
+
+def l1_reg_all(model, reg_type=None):
+    l1_reg = None
+
+    for W in model.parameters():
+        if l1_reg is None:
+            l1_reg = torch.abs(W).sum()
+        else:
+            l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1)
+    return l1_reg
+
+def l1_reg_modules(model, reg_type=None):
+    l1_reg = 0
+
+    l1_reg += l1_reg_all(model.fc_omic)
+    l1_reg += l1_reg_all(model.mm)
+
+    return l1_reg
+
+def l1_reg_omic(model, reg_type=None):
+    l1_reg = 0
+
+    if hasattr(model, 'fc_omic'):
+        l1_reg += l1_reg_all(model.fc_omic)
+    else:
+        l1_reg += l1_reg_all(model)
+
+    return l1_reg
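+# These helpers return a scalar tensor; a typical (hypothetical) use is to add
+# `lambda_reg * l1_reg_all(model)` (or one of the module-specific variants above)
+# to the task loss before calling backward().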
+
+def get_custom_exp_code(args):
+    r"""
+    Updates the argparse.NameSpace with a custom experiment code.
+
+    Args:
+        - args (NameSpace)
+
+    Returns:
+        - args (NameSpace)
+    """
+    exp_code = '_'.join(args.split_dir.split('_')[:2])
+    dataset_path = 'datasets_csv'
+    param_code = ''
+
+    ### Model Type
+    if args.model_type == 'porpoise_mmf':
+        param_code += 'PorpoiseMMF'
+    elif args.model_type == 'porpoise_amil':
+        param_code += 'PorpoiseAMIL'
+    elif args.model_type == 'max_net' or args.model_type == 'snn':
+        param_code += 'SNN'
+    elif args.model_type == 'amil':
+        param_code += 'AMIL'
+    elif args.model_type == 'deepset':
+        param_code += 'DS'
+    elif args.model_type == 'mi_fcn':
+        param_code += 'MIFCN'
+    elif args.model_type == 'mcat':
+        param_code += 'MCAT'
+    else:
+        raise NotImplementedError
+
+    ### Loss Function
+    param_code += '_%s' % args.bag_loss
+    if args.bag_loss in ['nll_surv']:
+        param_code += '_a%s' % str(args.alpha_surv)
+
+    ### Learning Rate
+    if args.lr != 2e-4:
+        param_code += '_lr%s' % format(args.lr, '.0e')
+
+    ### L1-Regularization
+    if args.reg_type != 'None':
+        param_code += '_%sreg%s' % (args.reg_type, format(args.lambda_reg, '.0e'))
+
+    if args.dropinput:
+        param_code += '_drop%s' % str(int(args.dropinput*100))
+
+    param_code += '_%s' % args.which_splits.split("_")[0]
+
+    ### Batch Size
+    if args.batch_size != 1:
+        param_code += '_b%s' % str(args.batch_size)
+
+    ### Gradient Accumulation
+    if args.gc != 1:
+        param_code += '_gc%s' % str(args.gc)
+
+    ### Applying Which Features
+    if args.apply_sigfeats:
+        param_code += '_sig'
+        dataset_path += '_sig'
+    elif args.apply_mutsig:
+        param_code += '_mutsig'
+        dataset_path += '_mutsig'
+
+    ### Fusion Operation
+    if args.fusion != "None":
+        param_code += '_' + args.fusion
+
+    ### Updating
+    args.exp_code = exp_code + "_" + param_code
+    args.param_code = param_code
+    args.dataset_path = dataset_path
+
+    return args
\ No newline at end of file