Diff of /utils.py [000000] .. [352cae]

--- a
+++ b/utils.py
@@ -0,0 +1,265 @@
+import os
+import random
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset, DataLoader, SequentialSampler
+
+from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score, cumulative_dynamic_auc
+
+
+def seed_worker(worker_id):
+    # Re-seed NumPy and random in each DataLoader worker so that worker-side
+    # randomness is reproducible given the DataLoader generator's seed.
+    worker_seed = torch.initial_seed() % 2**32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+def collate(batch):
+    # Keep items that are not needed as tensors during training (event times,
+    # slide ids) as plain NumPy arrays / Python lists.
+    img = torch.cat([item[0] for item in batch], dim=0)
+    img2 = torch.cat([item[1] for item in batch], dim=0)
+    label = torch.LongTensor([item[2] for item in batch])
+    event_time = np.array([item[3] for item in batch])
+    censorship = torch.FloatTensor([item[4] for item in batch])
+    stage = torch.LongTensor([item[5] for item in batch])
+    slide_id = [item[6] for item in batch]
+
+    return [img, img2, label, event_time, censorship, stage, slide_id]
+
+
+class FeatureBagsDataset(Dataset):
+    def __init__(self, df, data_dir, input_feature_size, stage_class):
+        self.slide_df = df.copy().reset_index(drop=True)
+        self.data_dir = data_dir
+        self.input_feature_size = input_feature_size
+        self.stage_class = stage_class
+    
+    def _get_feature_path(self, slide_id):
+        return os.path.join(self.data_dir, f"{slide_id}_Mergedfeatures.pt")
+
+    def __getitem__(self, idx):
+        slide_id = self.slide_df["slide_id"][idx]
+        stage = self.slide_df["stage"][idx]
+        label = self.slide_df["disc_label"][idx]
+        event_time = self.slide_df["recurrence_years"][idx]
+        censorship = self.slide_df["censorship"][idx]
+
+        full_path = self._get_feature_path(slide_id)
+
+        features = torch.load(full_path)
+
+        # Merged features: mean-pool each bag's patch-level features into one vector.
+        features_merged = torch.from_numpy(np.array([x[0].mean(0) for x in features]))
+
+        # Alternative: keep all patch-level features, flattened into one array.
+        # Which representation works best is an empirical choice.
+        features_flattened = torch.from_numpy(np.concatenate([x[0] for x in features]))
+
+        return features_merged, features_flattened, label, event_time, censorship, stage, slide_id
+
+    def __len__(self):
+        return len(self.slide_df)
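+
+# Illustrative sketch only (not the project's pipeline): how this dataset is meant
+# to be used, assuming each "<slide_id>_Mergedfeatures.pt" file holds a list of
+# tuples whose first element is an (n_patches, input_feature_size) feature array
+# (this layout is inferred from __getitem__ above; the argument values below are
+# placeholders):
+#
+#   dataset = FeatureBagsDataset(df, data_dir="features/", input_feature_size=1024, stage_class=4)
+#   merged, flattened, label, event_time, censorship, stage, slide_id = dataset[0]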
+
+def define_data_sampling(train_split, val_split, method, workers):
+    # Reproducibility of DataLoader.
+    g = torch.Generator()
+    g.manual_seed(0)
+
+    # Set up training data sampler.
+    if method == "random":
+        print("random sampling setting")
+        train_loader = DataLoader(
+            dataset=train_split,
+            batch_size=1,  # model expects one bag of features at a time.
+            shuffle=True,
+            collate_fn=collate,
+            num_workers=workers,
+            pin_memory=True,
+            worker_init_fn=seed_worker,
+            generator=g,
+        )
+    else:
+        raise NotImplementedError(f"Sampling method '{method}' not implemented.")
+
+    val_loader = DataLoader(
+            dataset=val_split,
+            batch_size=1,  # model expects one bag of features at a time.
+            sampler=SequentialSampler(val_split),
+            collate_fn=collate,
+            num_workers=workers,
+            pin_memory=True,
+            worker_init_fn=seed_worker,
+            generator=g,
+    )
+
+    return train_loader, val_loader
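+
+# Usage sketch (illustrative; `train_split` and `val_split` are assumed to be
+# FeatureBagsDataset instances, and the worker count is a placeholder):
+#
+#   train_loader, val_loader = define_data_sampling(train_split, val_split, method="random", workers=4)
+#   for img, img2, label, event_time, censorship, stage, slide_id in train_loader:
+#       ...  # one bag of features per iteration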
+
+class MonitorBestModelEarlyStopping:
+    """Early-stops training if the validation loss does not improve within a given patience, and tracks the best model."""
+    def __init__(self, patience=15, min_epochs=20, saving_checkpoint=True):
+        """
+        Args:
+            patience (int): How many epochs to wait after the last validation-loss improvement.
+                            Default: 15
+            min_epochs (int): Earliest epoch at which early stopping can trigger.
+                            Default: 20
+            saving_checkpoint (bool): If True, save a checkpoint whenever the monitored scores improve.
+                            Default: True
+        """
+        self.patience = patience
+        self.min_epochs = min_epochs
+        self.counter = 0
+        self.early_stop = False
+        
+        self.eval_loss_min = np.inf
+        self.best_loss_score = None
+        self.best_epoch_loss = None
+
+        self.best_CI_score = 0.0
+        self.best_metrics_score = None
+        self.best_epoch_CI = None
+
+        self.saving_checkpoint = saving_checkpoint
+
+    def __call__(self, epoch, eval_loss, eval_cindex, eval_other_metrics, model, log_dir):
+
+        loss_score = -eval_loss
+        CI_score = eval_cindex
+        metrics_score = eval_other_metrics
+
+        # Initialize the scores at the first call and start monitoring.
+        if self.best_loss_score is None:
+            self._update_loss_scores(loss_score, eval_loss, epoch)
+            self._update_metrics_scores(CI_score, metrics_score, epoch)
+            self.save_checkpoint(model, log_dir, epoch)
+
+        # Eval loss stopped improving: increment the early-stopping counter.
+        # We recommend running early stopping on the loss.
+        elif loss_score < self.best_loss_score:
+            self.counter += 1
+            print(f'Validation loss did not decrease: early-stopping counter {self.counter} out of {self.patience}')
+            if self.counter >= self.patience and epoch > self.min_epochs:
+                self.early_stop = True
+        # Eval loss keeps decreasing: reset the counter and update the best scores.
+        else:
+            print(f'Epoch {epoch}: validation loss decreased ({self.eval_loss_min:.6f} --> {eval_loss:.6f})')
+            self._update_loss_scores(loss_score, eval_loss, epoch)
+            self.save_checkpoint(model, log_dir, epoch)
+            self.counter = 0
+
+        # The minimum loss and the best C-index may occur at slightly different
+        # epochs. With the patience-based early stop this is fine, but it is
+        # better to also save the model with the highest C-index.
+        if CI_score > self.best_CI_score:
+            self._update_metrics_scores(CI_score, metrics_score, epoch)
+            self.save_checkpoint(model, log_dir, epoch)
+
+    def save_checkpoint(self, model, log_dir, epoch):
+        filepath = os.path.join(log_dir, f"{epoch}_checkpoint.pt")
+        if self.saving_checkpoint and not os.path.exists(filepath):
+            print(f"Saving model checkpoint to {filepath}")
+            torch.save(model.state_dict(), filepath)
+
+    def _update_loss_scores(self, loss_score, eval_loss, epoch):
+        self.eval_loss_min = eval_loss
+        self.best_loss_score = loss_score
+        self.best_epoch_loss = epoch
+        print(f'Updating loss at epoch {self.best_epoch_loss} -> {self.eval_loss_min:.6f}')
+    
+    def _update_metrics_scores(self, CI_score, metrics_score, epoch):
+        # Even when the loss decreases, the C-index does not necessarily increase; also keep track of the best C-index.
+        self.best_CI_score = CI_score
+        self.best_epoch_CI = epoch
+        self.best_metrics_score = metrics_score
+        print(f'Updating C-index at epoch {self.best_epoch_CI} -> {self.best_CI_score:.6f}')
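+
+# Usage sketch (illustrative; `train_one_epoch` and `evaluate` are placeholders
+# for the project's own training and validation routines):
+#
+#   early_stopping = MonitorBestModelEarlyStopping(patience=15, min_epochs=20)
+#   for epoch in range(max_epochs):
+#       train_one_epoch(model, train_loader)
+#       eval_loss, eval_cindex, other_metrics = evaluate(model, val_loader)
+#       early_stopping(epoch, eval_loss, eval_cindex, other_metrics, model, log_dir)
+#       if early_stopping.early_stop:
+#           break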
+
+def get_lr(optimizer):
+    # Return the learning rate of the first parameter group.
+    for param_group in optimizer.param_groups:
+        return param_group["lr"]
+
+def print_model(model):
+    print(model)
+    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(f"Model has {n_trainable_params} parameters")
+
+def get_survival_data_for_BS(df, time_col_name, censorship_col_name='censorship'):
+    # The Brier score (BS) and related survival metrics need censorship and
+    # event times in a specific format, used to estimate the censoring
+    # distribution: a structured array with the binary event indicator as first
+    # field (True: event occurred; False: censored) and the time of event or
+    # time of censoring as second field.
+    val = df[[censorship_col_name, time_col_name]].values
+    max_time = df[time_col_name].max()
+
+    y = np.empty(len(df), dtype=[('cens', '?'), ('time', '<f8')])
+    for i in range(len(df)):
+        y[i] = (bool(1 - val[i][0]), val[i][1])  # Note that we take the complement of the censorship status.
+    return max_time, y
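+
+# Example of the resulting structure (illustrative values): for censorship
+# [0, 1] and times [2.5, 4.0] years, y would be
+#   array([( True, 2.5), (False, 4.0)], dtype=[('cens', '?'), ('time', '<f8')])
+# i.e. the first patient had an event at 2.5 years and the second was censored at 4.0 years.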
+
+def get_bins_time_value(df, n_bins, time_col_name, label_time_col_name='disc_label', censorship_col_name='censorship'):
+    # Retrieve the bin edges used to discretize the continuous time scale.
+    # Note that the quantiles are computed on uncensored cases only.
+    uncensored_df = df[df[censorship_col_name] == 0]
+    _, q_bins = pd.qcut(uncensored_df[time_col_name], q=n_bins, retbins=True, labels=False)
+    q_bins[0] = 0
+    q_bins[-1] = float('inf')
+
+    # q_bins has length n_bins + 1; there is no need to return q_bins[0] == 0.
+    return q_bins[1:]
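+
+# Example (illustrative numbers): with n_bins=4 and uncensored recurrence times
+# whose quartile edges are [0.3, 1.2, 2.7, 4.9, 8.0], the clamped edges become
+# [0, 1.2, 2.7, 4.9, inf] and this function returns [1.2, 2.7, 4.9, inf],
+# i.e. the upper edge of each time bin.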
+
+def compute_surv_metrics_eval(
+    bins_values,
+    all_survival_probs,
+    all_risk_scores,
+    train_BS,
+    test_BS,
+    years_of_interest=[1.0, 2.0, 3.0, 5.0],
+    yearI_of_interest=1.0,
+    yearF_of_interest=5.0,
+    time_step=0.5):
+
+    # The continuous time scale was discretized into N bins for training, so there
+    # are only N survival probabilities, one per time step.
+    # The Brier score does not return exactly the same value on the discretized
+    # and the continuous time scale, because it uses the distribution of censoring
+    # even within an interval that has a constant survival probability.
+    # BS, IBS and AUC are time-dependent and step-dependent scores. See the notebook for examples.
+
+    # The years of interest are a modelling choice; map each one to its time bin.
+    corresponding_bins = np.asarray([np.argwhere(i < bins_values)[0, 0] for i in years_of_interest])
+    
+    # Note that this requires the survival times in survival_test to lie within
+    # the range of the survival times in survival_train. This can be achieved by
+    # specifying `times` accordingly, e.g. by setting times[-1] slightly below the
+    # maximum expected follow-up time.
+    _ , BS = brier_score(
+        survival_train=train_BS, 
+        survival_test=test_BS, 
+        estimate=all_survival_probs[:,corresponding_bins], 
+        times=years_of_interest)
+
+    # The Integrated Brier Score (IBS) summarizes the model performance over all
+    # evaluated times. Both the time points (at least two) and the time step are
+    # a choice; note that the time step has an impact on the value of the IBS.
+    years_of_interest_steps = np.arange(yearI_of_interest, yearF_of_interest + time_step, step=time_step, dtype=float)  # Add one step so the final year is included.
+    corresponding_bins = np.asarray([np.argwhere(i < bins_values)[0, 0] for i in years_of_interest_steps])
+
+    IBS = integrated_brier_score(
+        survival_train=train_BS, 
+        survival_test=test_BS, 
+        estimate=all_survival_probs[:,corresponding_bins], 
+        times=years_of_interest_steps)
+
+    # The AUC extends to survival data by defining sensitivity (true positive rate)
+    # and specificity (true negative rate) as time-dependent measures. Cumulative
+    # cases are all individuals that experienced an event prior to or at time t
+    # (ti <= t), whereas dynamic controls are those with ti > t. The associated
+    # cumulative/dynamic AUC quantifies how well a model can distinguish subjects
+    # who fail by a given time (ti <= t) from subjects who fail after this time (ti > t).
+    cumAUC, meanAUC = cumulative_dynamic_auc(
+        survival_train=train_BS,
+        survival_test=test_BS,
+        estimate=all_risk_scores,  # Like the C-index, the AUC uses risk scores, not event-free probabilities.
+        times=years_of_interest_steps, tied_tol=1e-08)
+    
+    c_index_ipcw = concordance_index_ipcw(
+        survival_train=train_BS,
+        survival_test=test_BS,
+        estimate=all_risk_scores,
+        tied_tol=1e-08)[0]
+
+    return (BS, years_of_interest), (IBS, yearI_of_interest, yearF_of_interest), (cumAUC, meanAUC), c_index_ipcw
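+
+# Usage sketch (illustrative; the dataframes, survival probabilities and risk
+# scores are placeholders for the project's own evaluation outputs):
+#
+#   _, train_BS = get_survival_data_for_BS(train_df, time_col_name="recurrence_years")
+#   _, test_BS = get_survival_data_for_BS(test_df, time_col_name="recurrence_years")
+#   bins_values = get_bins_time_value(train_df, n_bins=4, time_col_name="recurrence_years")
+#   (BS, years), (IBS, y0, y1), (cumAUC, meanAUC), c_index = compute_surv_metrics_eval(
+#       bins_values, all_survival_probs, all_risk_scores, train_BS, test_BS)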