Diff of /utils.py [000000] .. [352cae]

import os
import numpy as np
import pandas as pd
import random
import h5py
import pickle
import itertools
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler

from sklearn.utils.class_weight import compute_sample_weight
from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score, cumulative_dynamic_auc


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def collate(batch):
    # Keep numpy arrays for items that do not need to be tensors during training.
    img = torch.cat([item[0] for item in batch], dim=0)
    img2 = torch.cat([item[1] for item in batch], dim=0)
    label = torch.LongTensor([item[2] for item in batch])
    event_time = np.array([item[3] for item in batch])
    censorship = torch.FloatTensor([item[4] for item in batch])
    stage = torch.LongTensor([item[5] for item in batch])
    slide_id = [item[6] for item in batch]

    return [img, img2, label, event_time, censorship, stage, slide_id]


class FeatureBagsDataset(Dataset):
    def __init__(self, df, data_dir, input_feature_size, stage_class):
        self.slide_df = df.copy().reset_index(drop=True)
        self.data_dir = data_dir
        self.input_feature_size = input_feature_size
        self.stage_class = stage_class

    def _get_feature_path(self, slide_id):
        return os.path.join(self.data_dir, f"{slide_id}_Mergedfeatures.pt")

    def __getitem__(self, idx):
        slide_id = self.slide_df["slide_id"][idx]
        stage = self.slide_df["stage"][idx]
        label = self.slide_df["disc_label"][idx]
        event_time = self.slide_df["recurrence_years"][idx]
        censorship = self.slide_df["censorship"][idx]

        full_path = self._get_feature_path(slide_id)

        features = torch.load(full_path)

        # Merged features: average the features within each bag.
        features_merged = torch.from_numpy(np.array([x[0].mean(0) for x in features]))

        # Alternative: keep all features, depending on what works best.
        features_flattened = torch.from_numpy(np.concatenate([x[0] for x in features]))

        return features_merged, features_flattened, label, event_time, censorship, stage, slide_id

    def __len__(self):
        return len(self.slide_df)

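# Illustrative usage sketch (not part of the original file): building a FeatureBagsDataset from a
# hypothetical CSV that already contains the columns read in __getitem__ above; the file name,
# feature directory, input_feature_size and stage_class values are assumptions for the example.
def _example_feature_bags_dataset():
    df = pd.read_csv("splits_example.csv")  # assumed columns: slide_id, stage, disc_label, recurrence_years, censorship
    dataset = FeatureBagsDataset(df, data_dir="features/", input_feature_size=1024, stage_class=4)
    merged, flattened, label, event_time, censorship, stage, slide_id = dataset[0]
    print(slide_id, merged.shape, flattened.shape)
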
def define_data_sampling(train_split, val_split, method, workers):
    # Reproducibility of the DataLoader.
    g = torch.Generator()
    g.manual_seed(0)

    # Set up the training data sampler.
    if method == "random":
        print("random sampling setting")
        train_loader = DataLoader(
            dataset=train_split,
            batch_size=1,  # The model expects one bag of features at a time.
            shuffle=True,
            collate_fn=collate,
            num_workers=workers,
            pin_memory=True,
            worker_init_fn=seed_worker,
            generator=g,
        )
    else:
        raise Exception(f"Sampling method '{method}' not implemented.")

    val_loader = DataLoader(
        dataset=val_split,
        batch_size=1,  # The model expects one bag of features at a time.
        sampler=SequentialSampler(val_split),
        collate_fn=collate,
        num_workers=workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )

    return train_loader, val_loader

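# Illustrative usage sketch (not part of the original file): wiring two FeatureBagsDataset splits into
# loaders with define_data_sampling; the DataFrames, feature directory and worker count are assumptions.
def _example_define_data_sampling(train_df, val_df):
    train_split = FeatureBagsDataset(train_df, data_dir="features/", input_feature_size=1024, stage_class=4)
    val_split = FeatureBagsDataset(val_df, data_dir="features/", input_feature_size=1024, stage_class=4)
    train_loader, val_loader = define_data_sampling(train_split, val_split, method="random", workers=4)
    for img, img2, label, event_time, censorship, stage, slide_id in train_loader:
        print(slide_id, img.shape, img2.shape)
        break
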
class MonitorBestModelEarlyStopping:
    """Early-stops training when the validation loss stops improving for a given patience, and keeps track of the best model."""
    def __init__(self, patience=15, min_epochs=20, saving_checkpoint=True):
        """
        Args:
            patience (int): How many epochs to wait after the last validation-loss improvement.
                            Default: 15
            min_epochs (int): Earliest epoch at which stopping is allowed.
                              Default: 20
            saving_checkpoint (bool): If True, save_checkpoint() writes the model state to disk.
                                      Default: True
        """
        #self.warmup = warmup
        self.patience = patience
        self.min_epochs = min_epochs
        self.counter = 0
        self.early_stop = False

        self.eval_loss_min = np.inf
        self.best_loss_score = None
        self.best_epoch_loss = None

        self.best_CI_score = 0.0
        self.best_metrics_score = None
        self.best_epoch_CI = None

        self.saving_checkpoint = saving_checkpoint

    def __call__(self, epoch, eval_loss, eval_cindex, eval_other_metrics, model, log_dir):

        loss_score = -eval_loss
        CI_score = eval_cindex
        metrics_score = eval_other_metrics

        # Initialize the best scores at the first call (epoch 0) and start monitoring.
        if self.best_loss_score is None:
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self._update_metrics_scores(CI_score, metrics_score, epoch)
            #self.save_checkpoint(model, log_dir, epoch)

        # Validation loss got worse. Early stopping is driven by the loss.
        elif loss_score < self.best_loss_score:
            self.counter += 1
            print(f'Validation loss did not decrease: early stopping counter {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.min_epochs:
                self.early_stop = True
        # Validation loss keeps decreasing.
        else:
            print(f'Epoch {epoch} validation loss decreased ({self.eval_loss_min:.6f} --> {eval_loss:.6f})')
            self._update_loss_scores(loss_score, eval_loss, epoch)
            #self.save_checkpoint(model, log_dir, epoch)
            self.counter = 0

        # The minimum loss and the best C-index may not occur at the same epoch. Patience covers most of
        # that lag, but it is better to also track (and save) based on the highest C-index.
        if CI_score > self.best_CI_score:
            self._update_metrics_scores(CI_score, metrics_score, epoch)
            #self.save_checkpoint(model, log_dir, epoch)

    def save_checkpoint(self, model, log_dir, epoch):
        filepath = os.path.join(log_dir, f"{epoch}_checkpoint.pt")
        if self.saving_checkpoint and not os.path.exists(filepath):
            print("Saving model")
            torch.save(model.state_dict(), filepath)

    def _update_loss_scores(self, loss_score, eval_loss, epoch):
        self.eval_loss_min = eval_loss
        self.best_loss_score = loss_score
        self.best_epoch_loss = epoch
        print(f'Updating loss at epoch {self.best_epoch_loss} -> {self.eval_loss_min:.6f}')

    def _update_metrics_scores(self, CI_score, metrics_score, epoch):
        # A decreasing loss does not guarantee an increasing C-index, so the best C-index is tracked separately.
        self.best_CI_score = CI_score
        self.best_epoch_CI = epoch
        self.best_metrics_score = metrics_score
        print(f'Updating C-index at epoch {self.best_epoch_CI} -> {self.best_CI_score:.6f}')

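# Illustrative usage sketch (not part of the original file): driving MonitorBestModelEarlyStopping from a
# training loop; train_one_epoch and evaluate are hypothetical helpers, not defined in this file.
def _example_early_stopping_loop(model, train_loader, val_loader, max_epochs=100, log_dir="runs/"):
    early_stopping = MonitorBestModelEarlyStopping(patience=15, min_epochs=20, saving_checkpoint=True)
    for epoch in range(max_epochs):
        train_one_epoch(model, train_loader)  # hypothetical helper
        val_loss, val_cindex, other_metrics = evaluate(model, val_loader)  # hypothetical helper
        early_stopping(epoch, val_loss, val_cindex, other_metrics, model, log_dir)
        if early_stopping.early_stop:
            print(f"Stopping at epoch {epoch}; best loss at epoch {early_stopping.best_epoch_loss}, "
                  f"best C-index at epoch {early_stopping.best_epoch_CI}")
            break
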
def get_lr(optimizer):
    # Return the learning rate of the first parameter group.
    for param_group in optimizer.param_groups:
        return param_group["lr"]

def print_model(model):
    print(model)
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {n_trainable_params} trainable parameters")

def get_survival_data_for_BS(df, time_col_name, censorship_col_name='censorship'):
    # To compute the Brier score (BS), sksurv needs censorship status and times in a specific format,
    # from which it estimates the censoring distribution:
    # a structured array whose first field is the binary event indicator (1 = event occurred; 0 = censored)
    # and whose second field is the time of event or time of censoring.

    val = df[[censorship_col_name, time_col_name]].values
    max_time = df[time_col_name].max()

    y = np.empty(len(df), dtype=[('cens', '?'), ('time', '<f8')])
    for i in range(len(df)):
        y[i] = (bool(1 - val[i][0]), val[i][1])  # Convert the censorship status into an event indicator.
    return max_time, y

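# Illustrative sketch (not part of the original file): get_survival_data_for_BS applied to a small
# hypothetical DataFrame, showing the structured (event, time) array that sksurv expects.
def _example_get_survival_data_for_BS():
    df = pd.DataFrame({
        "censorship": [0, 1, 0],           # 1 = censored, 0 = event observed
        "recurrence_years": [1.5, 3.0, 4.2],
    })
    max_time, y = get_survival_data_for_BS(df, time_col_name="recurrence_years")
    print(max_time)   # 4.2
    print(y["cens"])  # [ True False  True] -> event indicator
    print(y["time"])  # [1.5 3.  4.2]
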
def get_bins_time_value(df, n_bins, time_col_name, label_time_col_name='disc_label', censorship_col_name='censorship'):
    # Retrieve the value (upper edge) of each time bin from the dataset.
    # Note that the bins are computed on the uncensored cases only.
    uncensored_df = df[df[censorship_col_name] == 0]
    labels, q_bins = pd.qcut(uncensored_df[time_col_name], q=n_bins, retbins=True, labels=False)
    q_bins[0] = 0
    q_bins[-1] = float('inf')

    # q_bins has length n_bins + 1; there is no need to return q_bins[0] == 0.
    return q_bins[1:]

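# Illustrative sketch (not part of the original file): bin edges from uncensored follow-up times,
# using a small hypothetical DataFrame and n_bins=4.
def _example_get_bins_time_value():
    df = pd.DataFrame({
        "recurrence_years": [0.5, 1.0, 1.8, 2.5, 3.3, 4.0, 5.5, 7.0],
        "censorship": [0, 0, 0, 0, 0, 0, 0, 1],
    })
    bins = get_bins_time_value(df, n_bins=4, time_col_name="recurrence_years")
    print(bins)  # upper edge of each of the 4 bins; the last edge is inf
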
def compute_surv_metrics_eval(
    bins_values,
    all_survival_probs,
    all_risk_scores,
    train_BS,
    test_BS,
    years_of_interest=[1.0, 2.0, 3.0, 5.0],
    yearI_of_interest=1.0,
    yearF_of_interest=5.0,
    time_step=0.5):

    # The continuous time scale was discretized into N bins for training, so there are only N survival
    # probabilities, one per time bin.
    # The Brier score does not return exactly the same value on the discretized and the continuous time
    # scales, because it uses the distribution of censoring even within intervals that share the same
    # survival probability.
    # BS, IBS and AUC are time-dependent and step-dependent scores. See the notebook for some examples.

    #years_of_interest = [1.0, 2.0, 3.0, 5.0]  # Entirely a choice, depending on what makes sense.
    corresponding_bins = np.asarray([np.argwhere(i < bins_values)[0, 0] for i in years_of_interest])

    # Note that this requires the survival times in survival_test to lie within the range of the survival
    # times in survival_train. This can be achieved by specifying times accordingly, e.g. by setting
    # times[-1] slightly below the maximum expected follow-up time.
    _, BS = brier_score(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_survival_probs[:, corresponding_bins],
        times=years_of_interest)

    # The Integrated Brier Score (IBS) summarizes model performance over all available times.
    # Both the time points (at least two) and the time step are a choice; note that the time step has an
    # impact on the value of the IBS.

    #yearI_of_interest, yearF_of_interest = 1.0, 5.0  # Between 1Y and 5Y included.
    #time_step = 0.5  # 6 months.
    years_of_interest_steps = np.arange(yearI_of_interest, yearF_of_interest + time_step, step=time_step, dtype=float)  # Otherwise the final year is not included.
    corresponding_bins = np.asarray([np.argwhere(i < bins_values)[0, 0] for i in years_of_interest_steps])

    IBS = integrated_brier_score(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_survival_probs[:, corresponding_bins],
        times=years_of_interest_steps)

    # The AUC can be extended to survival data by defining sensitivity (true positive rate) and
    # specificity (true negative rate) as time-dependent measures.
    # Cumulative cases are all individuals that experienced an event prior to or at time t (ti <= t),
    # whereas dynamic controls are those with ti > t.
    # The associated cumulative/dynamic AUC quantifies how well a model can distinguish subjects who fail
    # by a given time (ti <= t) from subjects who fail after this time (ti > t).
    cumAUC, meanAUC = cumulative_dynamic_auc(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_risk_scores,  # Like the C-index, the AUC uses risk scores, not event-free probabilities.
        times=years_of_interest_steps, tied_tol=1e-08)

    c_index_ipcw = concordance_index_ipcw(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_risk_scores,
        tied_tol=1e-08)[0]

    return (BS, years_of_interest), (IBS, yearI_of_interest, yearF_of_interest), (cumAUC, meanAUC), c_index_ipcw
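
# Illustrative sketch (not part of the original file): end-to-end call of compute_surv_metrics_eval with
# hypothetical model outputs; the DataFrames, array shapes and n_bins value are assumptions for the example.
def _example_compute_surv_metrics_eval(train_df, test_df, all_survival_probs, all_risk_scores, n_bins=4):
    # all_survival_probs: (n_test_samples, n_bins) event-free probabilities, one column per time bin.
    # all_risk_scores:    (n_test_samples,) scores where a higher value means a higher risk.
    bins_values = get_bins_time_value(train_df, n_bins=n_bins, time_col_name="recurrence_years")
    _, train_BS = get_survival_data_for_BS(train_df, time_col_name="recurrence_years")
    _, test_BS = get_survival_data_for_BS(test_df, time_col_name="recurrence_years")
    (BS, bs_times), (IBS, y0, y1), (cumAUC, meanAUC), c_index_ipcw = compute_surv_metrics_eval(
        bins_values, all_survival_probs, all_risk_scores, train_BS, test_BS)
    print(f"IBS [{y0}-{y1}y]: {IBS:.3f}, mean AUC: {meanAUC:.3f}, IPCW C-index: {c_index_ipcw:.3f}")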