--- a
+++ b/utils.py
@@ -0,0 +1,892 @@
+# Base / Native
+import math
+import os
+import pickle
+import re
+import warnings
+warnings.filterwarnings('ignore')
+
+# Numerical / Array
+import lifelines
+from lifelines.utils import concordance_index
+from lifelines import CoxPHFitter
+from lifelines.datasets import load_regression_dataset
+from lifelines.utils import k_fold_cross_validation
+from lifelines.statistics import logrank_test
+from imblearn.over_sampling import RandomOverSampler
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import matplotlib.font_manager as font_manager
+import numpy as np
+import pandas as pd
+from PIL import Image
+import pylab
+import scipy
+import seaborn as sns
+from sklearn import preprocessing
+from sklearn.model_selection import train_test_split, KFold
+from sklearn.metrics import average_precision_score, auc, f1_score, roc_curve, roc_auc_score
+from sklearn.preprocessing import LabelBinarizer
+
+from scipy import interp
+mpl.rcParams['axes.linewidth'] = 3 #set the value globally
+
+# Torch
+import torch
+import torch.nn as nn
+from torch.nn import init, Parameter
+from torch.utils.data._utils.collate import *
+from torch.utils.data.dataloader import default_collate
+import torch_geometric
+from torch_geometric.data import Batch
+
+
+
+################
+# Regularization
+################
def regularize_weights(model, reg_type=None):
    """Return the L1 norm of every parameter of `model`.

    Parameters:
        model (nn.Module) -- network whose parameters are summed
        reg_type          -- unused; kept for interface compatibility

    Returns:
        0-dim tensor with the total L1 penalty, or None if the model has no parameters.
    """
    total = None
    for param in model.parameters():
        term = torch.abs(param).sum()  # same as param.norm(1)
        total = term if total is None else total + term
    return total
+
+
def regularize_path_weights(model, reg_type=None):
    """Return the L1 norm of the `classifier` and `linear` heads of `model.module`.

    Parameters:
        model    -- wrapper exposing `model.module.classifier` and `model.module.linear`
        reg_type -- unused; kept for interface compatibility
    """
    penalty = None
    heads = list(model.module.classifier.parameters()) + list(model.module.linear.parameters())
    for weight in heads:
        term = torch.abs(weight).sum()  # equivalent to weight.norm(1)
        penalty = term if penalty is None else penalty + term
    return penalty
+
+
def regularize_MM_weights(model, reg_type=None):
    """Return the L1 norm over all regularized sub-modules of a multimodal net.

    Sums torch.abs(W).sum() (== W.norm(1)) over the parameters of every
    sub-module named below that `model.module` actually defines, and returns
    None when none of them exist.

    Fixes: the original called `model.module.__hasattr__(name)`, but
    `__hasattr__` is not a real dunder method -- plain objects (including
    nn.Module) raise AttributeError on that lookup. The builtin `hasattr`
    is used instead, and the thirteen copy-pasted stanzas are collapsed
    into one loop.

    Parameters:
        model     -- wrapper exposing the multimodal net on `model.module`
        reg_type  -- unused; kept for interface compatibility
    """
    submodule_names = (
        'omic_net',
        'linear_h_path', 'linear_h_omic', 'linear_h_grph',
        'linear_z_path', 'linear_z_omic', 'linear_z_grph',
        'linear_o_path', 'linear_o_omic', 'linear_o_grph',
        'encoder1', 'encoder2', 'classifier',
    )
    l1_reg = None
    for name in submodule_names:
        if hasattr(model.module, name):
            for W in getattr(model.module, name).parameters():
                term = torch.abs(W).sum()
                l1_reg = term if l1_reg is None else l1_reg + term
    return l1_reg
+
+
def regularize_MM_omic(model, reg_type=None):
    """Return the L1 norm of the omic sub-network's parameters (None if absent).

    Fixes: the original called `model.module.__hasattr__('omic_net')`;
    `__hasattr__` is not a real dunder and raises AttributeError on plain
    objects, so the builtin `hasattr` is used instead.

    Parameters:
        model     -- wrapper exposing `model.module.omic_net` (optional)
        reg_type  -- unused; kept for interface compatibility
    """
    l1_reg = None
    if hasattr(model.module, 'omic_net'):
        for W in model.module.omic_net.parameters():
            term = torch.abs(W).sum()  # equivalent to W.norm(1)
            l1_reg = term if l1_reg is None else l1_reg + term
    return l1_reg
+
+
+
+################
+# Network Initialization
+################
def init_weights(net, init_type='orthogonal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- one of: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    'normal' was used in the original pix2pix/CycleGAN papers; xavier and
    kaiming may work better for some applications.
    """
    def _apply(m):
        cls_name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm weight is a vector, not a matrix, so only a normal draw applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(_apply)
+
+
def init_max_weights(module):
    """Self-normalizing (SNN-style) init for every Linear layer in `module`.

    Each Linear weight is drawn from N(0, 1/sqrt(fan_in)) and its bias
    (when present) is zeroed. Other layer types are untouched.

    Fixes: uses isinstance instead of `type(m) == nn.Linear` so Linear
    subclasses are covered, and guards against `Linear(bias=False)`,
    which previously crashed on `m.bias.data`.
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            stdv = 1. / math.sqrt(m.weight.size(1))
            m.weight.data.normal_(0, stdv)
            if m.bias is not None:
                m.bias.data.zero_()
+
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Register a network on GPU(s) (with DataParallel) and initialize its weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- normal | xavier | kaiming | orthogonal | max | none
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns:
        the (possibly DataParallel-wrapped) initialized network.
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPU wrapper

    if init_type == 'none':
        print("Init Type: Not initializing networks.")
    elif init_type == 'max':
        # 'max' layers self-normalize at construction time (see init_max_weights).
        print("Init Type: Self-Normalizing Weights")
    else:
        print("Init Type:", init_type)
        init_weights(net, init_type, init_gain=init_gain)
    return net
+
+
+
+################
+# Freeze / Unfreeze
+################
def unfreeze_unimodal(opt, model, epoch):
    """At epoch 5, unfreeze the pretrained unimodal sub-networks used by `opt.mode`.

    Modes containing an omic branch unfreeze `model.module.omic_net`; modes
    containing a graph branch unfreeze `model.module.grph_net`. Any other
    mode (or epoch) is a no-op.
    """
    if epoch != 5:
        return  # every branch of the original gated on epoch == 5
    mode = opt.mode
    if mode in ('graphomic', 'pathgraphomic'):
        dfs_unfreeze(model.module.omic_net)
        print("Unfreezing Omic")
        dfs_unfreeze(model.module.grph_net)
        print("Unfreezing Graph")
    elif mode in ('pathomic', 'omicomic'):
        dfs_unfreeze(model.module.omic_net)
        print("Unfreezing Omic")
    elif mode in ('pathgraph', 'graphgraph'):
        dfs_unfreeze(model.module.grph_net)
        print("Unfreezing Graph")
+
+
def dfs_freeze(model):
    """Recursively set requires_grad=False on the parameters of all children.

    Note: parameters registered directly on `model` itself (not on a child)
    are deliberately left untouched, matching the original traversal.
    """
    for child in model.children():
        for param in child.parameters():
            param.requires_grad = False
        dfs_freeze(child)
+
+
def dfs_unfreeze(model):
    """Recursively set requires_grad=True on the parameters of all children.

    Mirror image of dfs_freeze; parameters registered directly on `model`
    itself are left untouched, matching the original traversal.
    """
    for child in model.children():
        for param in child.parameters():
            param.requires_grad = True
        dfs_unfreeze(child)
+
+
def print_if_frozen(module):
    """Print, for each parameter of each direct child, whether it is trainable."""
    for idx, child in enumerate(module.children()):
        for param in child.parameters():
            header = "Learnable!!! %d:" % idx if param.requires_grad else "Still Frozen %d:" % idx
            print(header, child)
+
+
def unfreeze_vgg_features(model, epoch):
    """Unfreeze VGG feature layers past a schedule-dependent index.

    Only epoch 30 is in the schedule (layers with index > 45 are unfrozen);
    any other epoch raises KeyError, as in the original.
    """
    epoch_schedule = {30: 45}
    threshold = epoch_schedule[epoch]
    for idx, child in enumerate(model.features.children()):
        if idx <= threshold:
            print("Still Frozen %d:" % idx, child)
            continue
        print("Unfreezing %d:" % idx, child)
        for param in child.parameters():
            param.requires_grad = True
+
+
+
+################
+# Collate Utils
+################
def mixed_collate(batch):
    """Collate a batch whose samples mix torch_geometric graphs with tensors.

    Each sample is a tuple of modalities; the batch is transposed so each
    modality is collated on its own: torch_geometric Data objects via
    Batch.from_data_list, everything else via default_collate.

    Fixes: removed the unused locals `elem` / `elem_type`.

    Parameters:
        batch (list of tuples) -- one tuple of modalities per sample

    Returns:
        list -- one collated object per modality
    """
    transposed = zip(*batch)
    return [Batch.from_data_list(samples, []) if type(samples[0]) is torch_geometric.data.data.Data
            else default_collate(samples) for samples in transposed]
+
+
+
+################
+# Survival Utils
+################
def CoxLoss(survtime, censor, hazard_pred, device):
    """Negative partial log-likelihood of the Cox proportional-hazards model.

    Calculation credit: Travers Ching, cox-nnet (https://github.com/traversc/cox-nnet).

    Parameters:
        survtime    -- length-n survival times
        censor      -- length-n event indicators (multiplied into the log-likelihood)
        hazard_pred -- predicted log-hazards, reshaped to (n,)
        device      -- torch device the at-risk matrix is moved to

    Returns:
        0-dim tensor holding the loss.
    """
    n = len(survtime)
    # risk_mat[i, j] == 1 iff subject j is still at risk at subject i's time.
    risk_mat = np.zeros([n, n], dtype=int)
    for i in range(n):
        for j in range(n):
            risk_mat[i, j] = survtime[j] >= survtime[i]

    risk_mat = torch.FloatTensor(risk_mat).to(device)
    theta = hazard_pred.reshape(-1)
    partial_ll = theta - torch.log(torch.sum(torch.exp(theta) * risk_mat, dim=1))
    return -torch.mean(partial_ll * censor)
+
+
def accuracy(output, labels):
    """Fraction of rows where argmax(output) equals the label (as a tensor)."""
    predictions = output.max(1)[1].type_as(labels)
    n_correct = predictions.eq(labels).double().sum()
    return n_correct / len(labels)
+
+
def accuracy_cox(hazardsdata, labels):
    """Accuracy of median-dichotomized hazards against true survival events.

    Hazards above the median are predicted as events (1), the rest as 0.
    """
    cutoff = np.median(hazardsdata)
    predicted = (np.asarray(hazardsdata) > cutoff).astype(int)
    return np.sum(predicted == labels) / len(labels)
+
+
def cox_log_rank(hazardsdata, labels, survtime_all):
    """Log-rank test p-value between low- and high-hazard groups (median split).

    Patients with hazard above the median form the high-risk group; the
    p-value of lifelines' logrank_test between the two groups is returned.
    """
    cutoff = np.median(hazardsdata)
    high_risk = np.zeros([len(hazardsdata)], dtype=int)
    high_risk[hazardsdata > cutoff] = 1
    low_mask = high_risk == 0
    outcome = logrank_test(survtime_all[low_mask], survtime_all[~low_mask],
                           event_observed_A=labels[low_mask],
                           event_observed_B=labels[~low_mask])
    return outcome.p_value
+
+
def CIndex(hazards, labels, survtime_all):
    """Concordance index computed by brute force over comparable pairs.

    A pair (i, j) is comparable when subject i had an observed event
    (labels[i] == 1) and subject j survived longer. Concordant pairs
    (hazards[j] < hazards[i]) count 1; ties in hazard count 0.5.

    Fixes: the original tie branch re-tested `hazards[j] < hazards[i]`
    and was therefore unreachable; it now tests equality as intended.

    Raises ZeroDivisionError when there are no comparable pairs (as before).
    """
    concord = 0.
    total = 0.
    N_test = labels.shape[0]
    for i in range(N_test):
        if labels[i] == 1:
            for j in range(N_test):
                if survtime_all[j] > survtime_all[i]:
                    total += 1
                    if hazards[j] < hazards[i]:
                        concord += 1
                    elif hazards[j] == hazards[i]:
                        concord += 0.5
    return concord / total
+
+
def CIndex_lifeline(hazards, labels, survtime_all):
    """Concordance index via lifelines; hazards are negated so that a higher
    predicted hazard corresponds to shorter predicted survival."""
    return(concordance_index(survtime_all, -hazards, labels))
+
+
+
+################
+# Data Utils
+################
def addHistomolecularSubtype(data):
    """Insert a 'Histomolecular subtype' column derived from the coded labels.

    Molecular Subtype: IDHwt == 0, IDHmut-non-codel == 1, IDHmut-codel == 2
    Histology Subtype: astrocytoma == 0, oligoastrocytoma == 1, oligodendroglioma == 2, glioblastoma == 3
    Rows matching no rule keep the numeric placeholder 1.0.
    """
    result = data.copy()
    result.insert(loc=0, column='Histomolecular subtype', value=np.ones(len(data)))

    is_atc = np.logical_or(data['Histology'] == 0, data['Histology'] == 3)
    result.loc[np.logical_and(data['Molecular subtype'] == 0, is_atc), 'Histomolecular subtype'] = 'idhwt_ATC'
    result.loc[np.logical_and(data['Molecular subtype'] == 1, is_atc), 'Histomolecular subtype'] = 'idhmut_ATC'
    result.loc[np.logical_and(data['Molecular subtype'] == 2, data['Histology'] == 2), 'Histomolecular subtype'] = 'ODG'
    return result
+
+
def changeHistomolecularSubtype(data):
    """Recompute the 'Histomolecular subtype' column from scratch.

    Drops the existing column and re-derives it with exactly the same rules
    (Molecular Subtype: IDHwt == 0, IDHmut-non-codel == 1, IDHmut-codel == 2;
    Histology: astrocytoma == 0, oligoastrocytoma == 1, oligodendroglioma == 2,
    glioblastoma == 3).

    Fixes: the body verbatim-duplicated addHistomolecularSubtype after the
    drop; it now delegates to it.
    """
    data = data.drop(['Histomolecular subtype'], axis=1)
    return addHistomolecularSubtype(data)
+
+
def getCleanAllDataset(dataroot='./data/TCGA_GBMLGG/', ignore_missing_moltype=False, ignore_missing_histype=False, use_rnaseq=False):
    """Load, join and clean the TCGA GBM/LGG tabular dataset.

    Parameters:
        dataroot (str)                -- folder with all_dataset.csv, grade_data.csv and RNAseq z-score files
        ignore_missing_moltype (bool) -- drop patients with missing molecular subtype instead of imputing
        ignore_missing_histype (bool) -- drop patients with missing histology/grade instead of imputing
        use_rnaseq (bool)             -- also join per-gene RNAseq z-scores (columns suffixed '_rnaseq')

    Returns:
        (metadata, all_dataset) -- list of metadata column names and the cleaned
        DataFrame (indexed by 'TCGA ID'), with Grade shifted so WHO II/III/IV map
        to 0/1/2 (-1 = missing) and 'censored' flipped to an event indicator.
    """
    ### 1. Joining all_datasets.csv with grade data. Looks at columns with missing samples
    metadata = ['Histology', 'Grade', 'Molecular subtype', 'TCGA ID', 'censored', 'Survival months']
    all_dataset = pd.read_csv(os.path.join(dataroot, 'all_dataset.csv')).drop('indexes', axis=1)
    all_dataset.index = all_dataset['TCGA ID']

    all_grade = pd.read_csv(os.path.join(dataroot, 'grade_data.csv'))
    all_grade['Histology'] = all_grade['Histology'].str.replace('astrocytoma (glioblastoma)', 'glioblastoma', regex=False)
    all_grade.index = all_grade['TCGA ID']
    assert pd.Series(all_dataset.index).equals(pd.Series(sorted(all_grade.index)))

    all_dataset = all_dataset.join(all_grade[['Histology', 'Grade', 'Molecular subtype']], how='inner')
    cols = all_dataset.columns.tolist()
    cols = cols[-3:] + cols[:-3]  # move the three newly-joined columns to the front
    all_dataset = all_dataset[cols]

    if use_rnaseq:
        gbm = pd.read_csv(os.path.join(dataroot, 'mRNA_Expression_z-Scores_RNA_Seq_RSEM.txt'), sep='\t', skiprows=1, index_col=0)
        lgg = pd.read_csv(os.path.join(dataroot, 'mRNA_Expression_Zscores_RSEM.txt'), sep='\t', skiprows=1, index_col=0)
        gbm = gbm[gbm.columns[~gbm.isnull().all()]]  # drop all-NaN sample columns
        lgg = lgg[lgg.columns[~lgg.isnull().all()]]
        glioma_RNAseq = gbm.join(lgg, how='inner').T  # transpose: rows = patients, cols = genes
        glioma_RNAseq = glioma_RNAseq.dropna(axis=1)
        glioma_RNAseq.columns = [gene+'_rnaseq' for gene in glioma_RNAseq.columns]
        glioma_RNAseq.index = [patname[:12] for patname in glioma_RNAseq.index]  # truncate barcode to the 12-char patient ID
        glioma_RNAseq = glioma_RNAseq.iloc[~glioma_RNAseq.index.duplicated()]
        glioma_RNAseq.index.name = 'TCGA ID'
        all_dataset = all_dataset.join(glioma_RNAseq, how='inner')

    # Sanity checks: molecular subtype, IDH status and 1p19q codeletion must be missing for the same patients.
    pat_missing_moltype = all_dataset[all_dataset['Molecular subtype'].isna()].index
    pat_missing_idh = all_dataset[all_dataset['idh mutation'].isna()].index
    pat_missing_1p19q = all_dataset[all_dataset['codeletion'].isna()].index
    print("# Missing Molecular Subtype:", len(pat_missing_moltype))
    print("# Missing IDH Mutation:", len(pat_missing_idh))
    print("# Missing 1p19q Codeletion:", len(pat_missing_1p19q))
    assert pat_missing_moltype.equals(pat_missing_idh)
    assert pat_missing_moltype.equals(pat_missing_1p19q)
    pat_missing_grade =  all_dataset[all_dataset['Grade'].isna()].index
    pat_missing_histype = all_dataset[all_dataset['Histology'].isna()].index
    print("# Missing Histological Subtype:", len(pat_missing_histype))
    print("# Missing Grade:", len(pat_missing_grade))
    assert pat_missing_histype.equals(pat_missing_grade)

    ### 2. Impute Missing Genomic Data: Removes patients with missing molecular subtype / idh mutation / 1p19q. Else imputes with median value of each column. Fills missing Molecular subtype with "Missing"
    if ignore_missing_moltype: 
        all_dataset = all_dataset[all_dataset['Molecular subtype'].isna() == False]
    for col in all_dataset.drop(metadata, axis=1).columns:
        # NOTE(review): the Molecular-subtype fillna below is loop-invariant; re-running it per column is redundant but harmless.
        all_dataset['Molecular subtype'] = all_dataset['Molecular subtype'].fillna('Missing')
        all_dataset[col] = all_dataset[col].fillna(all_dataset[col].median())

    ### 3. Impute Missing Histological Data: Removes patients with missing histological subtype / grade. Else imputes with "missing" / grade -1
    if ignore_missing_histype: 
        all_dataset = all_dataset[all_dataset['Histology'].isna() == False]
    else:
        all_dataset['Grade'] = all_dataset['Grade'].fillna(1)  # becomes -1 after the shift below
        all_dataset['Histology'] = all_dataset['Histology'].fillna('Missing')
    all_dataset['Grade'] = all_dataset['Grade'] - 2  # WHO grades II/III/IV -> 0/1/2

    ### 4. Adds Histomolecular subtype
    ms2int = {'Missing':-1, 'IDHwt':0, 'IDHmut-non-codel':1, 'IDHmut-codel':2}
    all_dataset[['Molecular subtype']] = all_dataset[['Molecular subtype']].applymap(lambda s: ms2int.get(s) if s in ms2int else s)
    hs2int = {'Missing':-1, 'astrocytoma':0, 'oligoastrocytoma':1, 'oligodendroglioma':2, 'glioblastoma':3}
    all_dataset[['Histology']] = all_dataset[['Histology']].applymap(lambda s: hs2int.get(s) if s in hs2int else s)
    all_dataset = addHistomolecularSubtype(all_dataset)
    metadata.extend(['Histomolecular subtype'])
    all_dataset['censored'] = 1 - all_dataset['censored']  # flip so 1 = event observed
    return metadata, all_dataset
+
+
+
+################
+# Analysis Utils
+################
def count_parameters(model):
    """Number of trainable (requires_grad) parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
+
+
def hazard2grade(hazard, p):
    """Bin a hazard into risk group 0, 1 or 2 using the two cut points in `p`."""
    for group, cutoff in enumerate(p[:2]):
        if hazard < cutoff:
            return group
    return 2
+
+
def p(n):
    """Return an aggregation function computing the n-th percentile.

    The function's __name__ is set to 'p<n>' so pandas .agg() produces
    readable column labels (e.g. 'Hazard_p0.75').
    """
    def _agg(values):
        return np.percentile(values, n)
    _agg.__name__ = 'p%s' % n
    return _agg
+
+
def natural_sort(l): 
    """Sort strings so embedded numbers compare numerically (e.g. a2 < a10)."""
    def token_key(s):
        return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split('([0-9]+)', s)]
    return sorted(l, key=token_key)
+
+
def CI_pm(data, confidence=0.95):
    """Format "mean ± half-width" of the Student-t confidence interval of `data`."""
    values = 1.0 * np.array(data)
    dof = len(values) - 1
    mean = np.mean(values)
    half_width = scipy.stats.sem(values) * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return "{0:.4f} ± ".format(mean) + "{0:.3f}".format(half_width)
+
+
def CI_interval(data, confidence=0.95):
    """Format the Student-t confidence interval of `data` as "lower, upper"."""
    values = 1.0 * np.array(data)
    dof = len(values) - 1
    mean = np.mean(values)
    half_width = scipy.stats.sem(values) * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return "{0:.3f}, ".format(mean - half_width) + "{0:.3f}".format(mean + half_width)
+
+
def poolSurvTestPD(ckpt_name='./checkpoints/TCGA_GBMLGG/surv_15_rnaseq/', model='pathgraphomic_fusion', split='test', zscore=False, agg_type='Hazard_mean'):
    """Pool per-patient survival predictions of `model` across the 15 CV splits.

    For each split, loads the pickled predictions, registers them against the
    split's metadata, aggregates patch-level hazards per patient (keeping the
    `agg_type` statistic) and concatenates everything into one DataFrame.

    Parameters:
        ckpt_name (str) -- checkpoint folder containing per-model prediction pickles
        model (str)     -- model subfolder / file prefix
        split (str)     -- which split's predictions to load (e.g. 'test')
        zscore (bool)   -- if True, z-score the pooled hazards within each split
        agg_type (str)  -- which per-patient hazard aggregate to keep (e.g. 'Hazard_mean')

    Returns:
        pd.DataFrame -- pooled predictions joined with metadata, with a freshly
        recomputed 'Histomolecular subtype' column.
    """
    all_dataset_regstrd_pooled = []    
    ignore_missing_moltype = 1 if 'omic' in model else 0
    ignore_missing_histype = 1 if 'grad' in ckpt_name else 0
    # Patch-level predictions (and VGG features) only exist for path/graph models that are not Cox baselines.
    use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if ((('path' in model) or ('graph' in model)) and ('cox' not in model)) else ('_', 'all_st', 0)
    use_rnaseq = '_rnaseq' if ('rnaseq' in ckpt_name and 'path' != model and 'pathpath' not in model and 'graph' != model and 'graphgraph' not in model) else ''

    for k in range(1,16):
        pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb'))    
        
        if 'cox' not in model:
            # pred = [hazard, survtime, censored, grade_pred, grade]; index 3 (grade_pred) is dropped before stacking.
            surv_all = pd.DataFrame(np.stack(np.delete(np.array(pred), 3))).T
            surv_all.columns = ['Hazard', 'Survival months', 'censored', 'Grade']
            data_cv = pickle.load(open('./data/TCGA_GBMLGG/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (roi_dir, ignore_missing_moltype, ignore_missing_histype, use_vgg_features, use_rnaseq), 'rb'))
            data_cv_splits = data_cv['cv_splits']
            data_cv_split_k = data_cv_splits[k]
            assert np.all(data_cv_split_k[split]['t'] == pred[1]) # Data is correctly registered
            all_dataset = data_cv['data_pd'].drop('TCGA ID', axis=1)
            all_dataset_regstrd = all_dataset.loc[data_cv_split_k[split]['x_patname']] # Subset of "all_datasets" (metadata) that is registered with "pred" (predictions)
            assert np.all(np.array(all_dataset_regstrd['Survival months']) == pred[1])
            assert np.all(np.array(all_dataset_regstrd['censored']) == pred[2])
            assert np.all(np.array(all_dataset_regstrd['Grade']) == pred[4])
            all_dataset_regstrd.insert(loc=0, column='Hazard', value = np.array(surv_all['Hazard']))
            all_dataset_regstrd.index.name = 'TCGA ID'
            # Aggregate patch-level hazards per patient, then keep only the requested statistic.
            hazard_agg = all_dataset_regstrd.groupby('TCGA ID').agg({'Hazard': ['mean', 'median', max, p(0.25), p(0.75)]})
            hazard_agg.columns = ["_".join(x) for x in hazard_agg.columns.ravel()]
            hazard_agg = hazard_agg[[agg_type]]
            hazard_agg.columns = ['Hazard']
            pred = hazard_agg.join(all_dataset, how='inner')

        if zscore: pred['Hazard'] = scipy.stats.zscore(np.array(pred['Hazard']))
        all_dataset_regstrd_pooled.append(pred)

    all_dataset_regstrd_pooled = pd.concat(all_dataset_regstrd_pooled)
    all_dataset_regstrd_pooled = changeHistomolecularSubtype(all_dataset_regstrd_pooled)
    return all_dataset_regstrd_pooled
+
+
def getAggHazardCV(ckpt_name='./checkpoints/TCGA_GBMLGG/surv_15_rnaseq/', model='pathgraphomic_fusion', split='test', agg_type='Hazard_mean'):
    """Compute the concordance index of `model` for each of the 15 CV splits.

    Mirrors poolSurvTestPD's registration logic: per split, loads the pickled
    predictions, aligns them with the split metadata, aggregates patch-level
    hazards per patient (column `agg_type`), and scores the C-index.

    Parameters:
        ckpt_name (str) -- checkpoint folder containing per-model prediction pickles
        model (str)     -- model subfolder / file prefix
        split (str)     -- which split's predictions to load (e.g. 'test')
        agg_type (str)  -- which per-patient hazard aggregate to score (e.g. 'Hazard_mean')

    Returns:
        list of 15 floats -- one C-index per CV split.
    """
    result = []
    
    ignore_missing_moltype = 1 if 'omic' in model else 0
    ignore_missing_histype = 1 if 'grad' in ckpt_name else 0
    use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if (('path' in model) or ('graph' in model)) else ('_', 'all_st', 0)
    use_rnaseq = '_rnaseq' if ('rnaseq' in ckpt_name and 'path' != model and 'pathpath' not in model and 'graph' != model and 'graphgraph' not in model) else ''

    for k in range(1,16):
        pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb'))    
        # pred = [hazard, survtime, censored, grade_pred, grade]; index 3 (grade_pred) is dropped before stacking.
        surv_all = pd.DataFrame(np.stack(np.delete(np.array(pred), 3))).T
        surv_all.columns = ['Hazard', 'Survival months', 'censored', 'Grade']
        data_cv = pickle.load(open('./data/TCGA_GBMLGG/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (roi_dir, ignore_missing_moltype, ignore_missing_histype, use_vgg_features, use_rnaseq), 'rb'))
        data_cv_splits = data_cv['cv_splits']
        data_cv_split_k = data_cv_splits[k]
        assert np.all(data_cv_split_k[split]['t'] == pred[1]) # Data is correctly registered
        all_dataset = data_cv['data_pd'].drop('TCGA ID', axis=1)
        all_dataset_regstrd = all_dataset.loc[data_cv_split_k[split]['x_patname']] # Subset of "all_datasets" (metadata) that is registered with "pred" (predictions)
        assert np.all(np.array(all_dataset_regstrd['Survival months']) == pred[1])
        assert np.all(np.array(all_dataset_regstrd['censored']) == pred[2])
        assert np.all(np.array(all_dataset_regstrd['Grade']) == pred[4])
        all_dataset_regstrd.insert(loc=0, column='Hazard', value = np.array(surv_all['Hazard']))
        all_dataset_regstrd.index.name = 'TCGA ID'
        # Aggregate patch-level hazards per patient, then keep only the requested statistic.
        hazard_agg = all_dataset_regstrd.groupby('TCGA ID').agg({'Hazard': ['mean', max, p(0.75)]})
        hazard_agg.columns = ["_".join(x) for x in hazard_agg.columns.ravel()]
        hazard_agg = hazard_agg[[agg_type]]
        hazard_agg.columns = ['Hazard']
        all_dataset_hazard = hazard_agg.join(all_dataset, how='inner')
        cin = CIndex_lifeline(all_dataset_hazard['Hazard'], all_dataset_hazard['censored'], all_dataset_hazard['Survival months'])
        result.append(cin)
        
    return result
+
+
def calcGradMetrics(ckpt_name='./checkpoints/grad_15/', model='pathgraphomic_fusion', split='test', avg='micro'):
    """Grade-classification metrics pooled over the 15 CV splits.

    For each split k, loads the pickled predictions (pred[3] = class scores,
    pred[4] = true grades), one-hot encodes the true grades, and computes
    ROC AUC, average precision, F1 and grade-IV F1.

    Fixes: `average` is now passed by keyword to roc_auc_score -- it was
    passed positionally, which raises TypeError on current scikit-learn
    where the parameter is keyword-only.

    Parameters:
        ckpt_name (str) -- checkpoint folder containing per-model prediction pickles
        model (str)     -- model subfolder / file prefix
        split (str)     -- which split's predictions to load (e.g. 'test')
        avg (str)       -- sklearn averaging mode for AUC / AP / F1

    Returns:
        np.ndarray of four "mean ± half-CI" strings: [AUC, AP, F1, F1 grade IV]
    """
    auc_all = []
    ap_all = []
    f1_all = []
    f1_gradeIV_all = []
    
    ignore_missing_moltype = 1 if 'omic' in model else 0
    ignore_missing_histype = 1 if 'grad' in ckpt_name else 0
    use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if (('path' in model) or ('graph' in model)) else ('_', 'all_st', 0)
    
    for k in range(1,16):
        pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb'))    
        grade_pred, grade = np.array(pred[3]), np.array(pred[4])
        enc = LabelBinarizer()
        enc.fit(grade)
        grade_oh = enc.transform(grade)
        rocauc = roc_auc_score(grade_oh, grade_pred, average=avg)
        ap = average_precision_score(grade_oh, grade_pred, average=avg)
        # NOTE(review): arguments are (predictions, truths); F1 is symmetric in the two, so the score is unaffected.
        f1 = f1_score(grade_pred.argmax(axis=1), grade, average=avg)
        f1_gradeIV = f1_score(grade_pred.argmax(axis=1), grade, average=None)[2]
        
        auc_all.append(rocauc)
        ap_all.append(ap)
        f1_all.append(f1)
        f1_gradeIV_all.append(f1_gradeIV)
        
    return np.array([CI_pm(auc_all), CI_pm(ap_all), CI_pm(f1_all), CI_pm(f1_gradeIV_all)])
+
+
+
+################
+# Plot Utils
+################
def makeKaplanMeierPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='omic', split='test', zscore=False, agg_type='Hazard_mean'):
    """Save Kaplan-Meier plots comparing predicted risk strata to true grades.

    Pools predictions via poolSurvTestPD, then for each histomolecular subtype
    (and once for all patients) plots ground-truth KM curves per grade
    (dashed) against KM curves of hazard-derived risk groups (solid), and
    writes '<model>_KM_<subtype>.png' into `ckpt_name`.

    Parameters:
        ckpt_name (str) -- checkpoint folder to read predictions from / write PNGs to
        model (str)     -- model subfolder / file prefix (also used for legend labels)
        split (str)     -- which split's predictions to plot (e.g. 'test')
        zscore (bool)   -- forwarded to poolSurvTestPD
        agg_type (str)  -- forwarded to poolSurvTestPD (per-patient hazard aggregate)
    """
    def hazard2KMCurve(data, subtype):
        # Split hazards into low/mid/high at the 33rd and 66th percentiles.
        p = np.percentile(data['Hazard'], [33, 66])
        # Degenerate-split guard; magic constant -- presumably tuned for this cohort's hazard scale, TODO confirm.
        if p[0] == p[1]: p[0] = 2.99997
        data.insert(0, 'grade_pred', [hazard2grade(hazard, p) for hazard in data['Hazard']])
        kmf_pred = lifelines.KaplanMeierFitter()
        kmf_gt = lifelines.KaplanMeierFitter()

        def get_name(model):
            # All fusion variants share the 'Pathomic F.' display name.
            mode2name = {'pathgraphomic':'Pathomic F.', 'pathomic':'Pathomic F.', 'graphomic':'Pathomic F.', 'path':'Histology CNN', 'graph':'Histology GCN', 'omic':'Genomic SNN'}
            for mode in mode2name.keys():
                if mode in model: return mode2name[mode]
            return 'N/A'

        fig = plt.figure(figsize=(10, 10), dpi=600)
        ax = plt.subplot()
        censor_style = {'ms': 20, 'marker': '+'}
        
        # Grade II (green): ground truth dashed vs. predicted low-risk solid.
        temp = data[data['Grade']==0]
        kmf_gt.fit(temp['Survival months']/365, temp['censored'], label="Grade II")
        kmf_gt.plot(ax=ax, show_censors=True, ci_show=False, c='g', linewidth=3, ls='--', markerfacecolor='black', censor_styles=censor_style)
        temp = data[data['grade_pred']==0]
        kmf_pred.fit(temp['Survival months']/365, temp['censored'], label="%s (Low)" % get_name(model))
        kmf_pred.plot(ax=ax, show_censors=True, ci_show=False, c='g', linewidth=4, ls='-', markerfacecolor='black', censor_styles=censor_style)

        # Grade III (blue): ground truth vs. predicted mid-risk.
        temp = data[data['Grade']==1]
        kmf_gt.fit(temp['Survival months']/365, temp['censored'], label="Grade III")
        kmf_gt.plot(ax=ax, show_censors=True, ci_show=False, c='b', linewidth=3, ls='--', censor_styles=censor_style)
        temp = data[data['grade_pred']==1]
        kmf_pred.fit(temp['Survival months']/365, temp['censored'], label="%s (Mid)" % get_name(model))
        kmf_pred.plot(ax=ax, show_censors=True, ci_show=False, c='b', linewidth=4, ls='-', censor_styles=censor_style)

        # Grade IV (red) is skipped for ODG, which has no grade-IV stratum here.
        if subtype != 'ODG':    
            temp = data[data['Grade']==2]
            kmf_gt.fit(temp['Survival months']/365, temp['censored'], label="Grade IV")
            kmf_gt.plot(ax=ax, show_censors=True, ci_show=False, c='r', linewidth=3, ls='--', censor_styles=censor_style)
            temp = data[data['grade_pred']==2]
            kmf_pred.fit(temp['Survival months']/365, temp['censored'], label="%s (High)" % get_name(model))
            kmf_pred.plot(ax=ax, show_censors=True, ci_show=False, c='r', linewidth=4, ls='-', censor_styles=censor_style)

        ax.set_xlabel('')
        ax.set_ylim(0, 1)
        ax.set_yticks(np.arange(0, 1.001, 0.5))

        ax.tick_params(axis='both', which='major', labelsize=40)    
        plt.legend(fontsize=32, prop=font_manager.FontProperties(family='Arial', style='normal', size=32))
        # Only the idhwt_ATC panel keeps its legend (shared across the figure set).
        if subtype != 'idhwt_ATC': ax.get_legend().remove()
        return fig
    
    data = poolSurvTestPD(ckpt_name, model, split, zscore, agg_type)
    for subtype in ['idhwt_ATC', 'idhmut_ATC', 'ODG']:
        fig = hazard2KMCurve(data[data['Histomolecular subtype'] == subtype], subtype)
        fig.savefig(ckpt_name+'/%s_KM_%s.png' % (model, subtype))
        
    fig = hazard2KMCurve(data, 'all')
    fig.savefig(ckpt_name+'/%s_KM_%s.png' % (model, 'all'))
+
+
def makeHazardSwarmPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='path', split='test', zscore=True, agg_type='Hazard_mean'):
    """Draw a swarm plot of hazard predictions grouped by histomolecular subtype
    and colored by WHO grade, saved to '<ckpt_name>/<model>_HSP.png'.

    Args:
        ckpt_name (str): Checkpoint directory holding the pooled split predictions.
        model (str): Model key used both to locate predictions and to name the output.
        split (str): Which data split to pool (e.g. 'test').
        zscore (bool): Passed to poolSurvTestPD; presumably z-scores hazards — confirm there.
        agg_type (str): Hazard aggregation column used by poolSurvTestPD.
    """
    mpl.rcParams['font.family'] = "arial"
    data = poolSurvTestPD(ckpt_name=ckpt_name, model=model, split=split, zscore=zscore, agg_type=agg_type)
    # Drop rows with missing labels (encoded as -1 upstream).
    data = data[data['Grade'] != -1]
    data = data[data['Histomolecular subtype'] != -1]
    # Map integer grade codes (0/1/2) to display names.
    data['Grade'] = data['Grade'].astype(int).astype(str)
    data['Grade'] = data['Grade'].str.replace('0', 'Grade II', regex=False)
    data['Grade'] = data['Grade'].str.replace('1', 'Grade III', regex=False)
    data['Grade'] = data['Grade'].str.replace('2', 'Grade IV', regex=False)
    # Map subtype codes to human-readable (line-wrapped) axis labels.
    data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace('idhwt_ATC', 'IDH-wt \n astryocytoma', regex=False)
    data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace('idhmut_ATC', 'IDH-mut \n astrocytoma', regex=False)
    data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace('ODG', 'Oligodendroglioma', regex=False)

    fig, ax = plt.subplots(dpi=600)
    ax.set_ylim([-2, 2.5]) # plt.ylim(-2, 2)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_yticks(np.arange(-2, 2.001, 1))

    sns.swarmplot(x = 'Histomolecular subtype', y='Hazard', data=data, hue='Grade',
                  palette={"Grade II":"#AFD275" , "Grade III":"#7395AE", "Grade IV":"#E7717D"},
                  size = 4, alpha = 0.9, ax=ax)

    ax.set_xlabel('') # ax.set_xlabel('Histomolecular subtype', size=16)
    ax.set_ylabel('') # ax.set_ylabel('Hazard (Z-Score)', size=16)
    ax.tick_params(axis='y', which='both', labelsize=20)
    ax.tick_params(axis='x', which='both', labelsize=15)
    # Fix: modern matplotlib requires a boolean here; the old string 'off' was
    # silently ignored (hence the original "doesn't work??" comment).
    ax.tick_params(axis='x', which='both', labelbottom=False)
    ax.legend(prop={'size': 8})
    fig.savefig(ckpt_name+'/%s_HSP.png' % (model))
+
+
def makeHazardBoxPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='omic', split='test', zscore=True, agg_type='Hazard_mean'):
    """Draw per-subtype box (+ strip) plots of hazard predictions by WHO grade,
    saved to '<ckpt_name>/<model>_HBP.png'.

    Three side-by-side panels: IDH-wt astrocytoma, IDH-mut astrocytoma, ODG
    (ODG has no Grade IV, so its panel only shows grades II/III).

    Args:
        ckpt_name (str): Checkpoint directory holding the pooled split predictions.
        model (str): Model key used both to locate predictions and to name the output.
        split (str): Which data split to pool (e.g. 'test').
        zscore (bool): Passed to poolSurvTestPD; presumably z-scores hazards — confirm there.
        agg_type (str): Hazard aggregation column used by poolSurvTestPD.
    """
    mpl.rcParams['font.family'] = "arial"
    # Fix: forward agg_type instead of hard-coding 'Hazard_mean' (the parameter
    # was previously accepted but ignored; default behavior is unchanged).
    data = poolSurvTestPD(ckpt_name, model, split, zscore, agg_type)
    # Map integer grade codes (0/1/2) to roman-numeral display names.
    data['Grade'] = data['Grade'].astype(int).astype(str)
    data['Grade'] = data['Grade'].str.replace('0', 'II', regex=False)
    data['Grade'] = data['Grade'].str.replace('1', 'III', regex=False)
    data['Grade'] = data['Grade'].str.replace('2', 'IV', regex=False)

    # Third panel (ODG) is narrower: it has one fewer grade category.
    fig, axes = plt.subplots(nrows=1, ncols=3, gridspec_kw={'width_ratios': [3, 3, 2]}, dpi=600)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.ylim(-2, 2)
    plt.yticks(np.arange(-2, 2.001, 1))
    #color_dict = {0: '#CF9498', 1: '#8CC7C8', 2: '#AAA0C6'}
    #color_dict = {0: '#F76C6C', 1: '#A8D0E6', 2: '#F8E9A1'}
    color_dict = ['#F76C6C', '#A8D0E6', '#F8E9A1']
    subtypes = ['idhwt_ATC', 'idhmut_ATC', 'ODG']

    for i in range(len(subtypes)):
        axes[i].spines["top"].set_visible(False)
        axes[i].spines["right"].set_visible(False)
        axes[i].xaxis.grid(False)
        axes[i].yaxis.grid(False)

        # Only the left-most panel keeps its y-axis; the others share it visually.
        if i > 0:
            axes[i].get_yaxis().set_visible(False)
            axes[i].spines["left"].set_visible(False)

        order = ["II","III","IV"] if subtypes[i] != 'ODG' else ["II", "III"]

        axes[i].xaxis.label.set_visible(False)
        axes[i].yaxis.label.set_visible(False)
        axes[i].tick_params(axis='y', which='both', labelsize=20)
        axes[i].tick_params(axis='x', which='both', labelsize=15)
        datapoints = data[data['Histomolecular subtype'] == subtypes[i]]
        sns.boxplot(y='Hazard', x="Grade", data=datapoints, ax = axes[i], color=color_dict[i], order=order)
        # Overlay the raw points on top of each box.
        sns.stripplot(y='Hazard', x='Grade', data=datapoints, alpha=0.2, jitter=0.2, color='k', ax = axes[i], order=order)
        axes[i].set_ylim(-2.5, 2.5)
        axes[i].set_yticks(np.arange(-2.0, 2.1, 1))

    #axes[2].legend(prop={'size': 10})
    fig.savefig(ckpt_name+'/%s_HBP.png' % (model))
+
+
def makeAUROCPlot(ckpt_name='./checkpoints/grad_15/', model_list=['path', 'omic', 'pathgraphomic_fusion'], split='test', avg='micro', use_zoom=False):
    """Plot mean ROC curves (± 1 std band over 15 CV splits) for each model,
    one figure per grade class plus one micro-averaged figure, saved to
    '<ckpt_name>/AUC_<class><_zoom>.png'.

    Args:
        ckpt_name (str): Checkpoint directory containing per-model pickled predictions
            named '<model>/<model>_<k><patch>pred_<split>.pkl' for k in 1..15.
        model_list (list[str]): Model keys to overlay on each figure.
        split (str): Which split's predictions to load (e.g. 'test').
        avg (str): Multi-class averaging strategy for the combined figure ('micro').
        use_zoom (bool): If True, render zoomed-in axes per `zoom_params` and omit legend.
    """
    mpl.rcParams['font.family'] = "arial"
    colors = {'path':'dodgerblue', 'graph':'orange', 'omic':'green', 'pathgraphomic_fusion':'crimson'}
    names = {'path':'Histology CNN', 'graph':'Histology GCN', 'omic':'Genomic SNN', 'pathgraphomic_fusion':'Pathomic F.'}
    # Per-class ([x-range], [y-range]) windows used when use_zoom is set.
    zoom_params = {0:([0.2, 0.4], [0.8, 1.0]),
                   1:([0.25, 0.45], [0.75, 0.95]),
                   2:([0.0, 0.2], [0.8, 1.0]),
                   'micro':([0.15, 0.35], [0.8, 1.0])}
    mean_fpr = np.linspace(0, 1, 100)
    classes = [0, 1, 2, avg]
    figures = []  # collect our own figures instead of scraping mpl's private Gcf registry

    ### 1. Looping over classes (one figure each)
    for i in classes:
        print("Class: " + str(i))
        fi = pylab.figure(figsize=(10,10), dpi=600, linewidth=0.2)
        figures.append(fi)
        axi = plt.subplot()

        ### 2. Looping over models
        for model in model_list:
            # Path/graph models were evaluated on ROI patches with VGG features.
            use_patch = '_patch_' if (('path' in model) or ('graph' in model)) else '_'

            ### 3. Looping over all 15 CV splits
            tprs, aucrocs = [], []
            for k in range(1, 16):
                pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb'))
                grade_pred, grade = np.array(pred[3]), np.array(pred[4])
                enc = LabelBinarizer()
                enc.fit(grade)
                grade_oh = enc.transform(grade)

                if i != avg:
                    # Per-class one-vs-rest ROC.
                    fpr, tpr, _ = roc_curve(grade_oh[:, i], grade_pred[:, i], drop_intermediate=False)
                else:
                    # Micro-average: pool all (class, sample) decisions jointly.
                    fpr, tpr, _ = roc_curve(grade_oh.ravel(), grade_pred.ravel())
                aucrocs.append(auc(fpr, tpr))
                # np.interp replaces the long-deprecated scipy.interp (same 1-D behavior);
                # resample every split's TPR onto a common FPR grid so curves can be averaged.
                tprs.append(np.interp(mean_fpr, fpr, tpr))
                tprs[-1][0] = 0.0

            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            mean_auc = np.mean(aucrocs)
            std_auc = np.std(aucrocs)
            print('\t'+'%s - AUC: %0.3f ± %0.3f' % (model, mean_auc, std_auc))

            # Emphasize the fusion model; zoomed figures use thicker lines.
            alpha = 0.8 if model == 'pathgraphomic_fusion' else 0.5
            lw = 6 if use_zoom else (4 if model == 'pathgraphomic_fusion' else 3)
            plt.plot(mean_fpr, mean_tpr, color=colors[model],
                     label=r'%s (AUC = %0.3f $\pm$ %0.3f)' % (names[model], mean_auc, std_auc), lw=lw, alpha=alpha)
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color=colors[model], alpha=0.1)

            if use_zoom:
                (xlo, xhi), (ylo, yhi) = zoom_params[i]
                plt.xlim([xlo-0.005, xhi+0.005])
                plt.ylim([ylo-0.005, yhi+0.005])
                axi.set_xticks(np.arange(xlo, xhi+0.001, 0.05))
                axi.set_yticks(np.arange(ylo, yhi+0.001, 0.05))
                axi.tick_params(axis='both', which='major', labelsize=26)
            else:
                plt.xlim([-0.05, 1.05])
                plt.ylim([-0.05, 1.05])
                axi.set_xticks(np.arange(0, 1.001, 0.2))
                axi.set_yticks(np.arange(0, 1.001, 0.2))
                axi.legend(loc="lower right", prop={'size': 20})
                axi.tick_params(axis='both', which='major', labelsize=30)
                #plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='navy', alpha=.8)

    zoom = '_zoom' if use_zoom else ''
    # Figures were created in `classes` order, so zip pairs each with its label.
    for cls, fig in zip(classes, figures):
        fig.savefig(ckpt_name+'/AUC_%s%s.png' % (cls, zoom))
\ No newline at end of file