|
a |
|
b/utils.py |
|
|
1 |
import os |
|
|
2 |
import numpy as np |
|
|
3 |
import pandas as pd |
|
|
4 |
import random |
|
|
5 |
import h5py |
|
|
6 |
import pickle |
|
|
7 |
import itertools |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
import torch |
|
|
10 |
from torch.utils.data import Dataset, DataLoader, SequentialSampler |
|
|
11 |
|
|
|
12 |
from sklearn.utils.class_weight import compute_sample_weight |
|
|
13 |
from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score, cumulative_dynamic_auc |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from the current torch seed.

    Meant to be used as the ``worker_init_fn`` of a ``DataLoader``. The seed
    comes from ``torch.initial_seed()`` reduced modulo 2**32.

    Args:
        worker_id: Worker index handed over by the DataLoader; unused here.
    """
    seed = torch.initial_seed() % (1 << 32)
    # The two seeding calls are independent of each other.
    random.seed(seed)
    np.random.seed(seed)
|
|
20 |
|
|
|
21 |
def collate(batch):
    """Merge a list of dataset samples into a single batch.

    Every sample is indexed by position: 0 and 1 are tensors concatenated
    along dim 0, 2 is the label, 3 the event time, 4 the censorship flag,
    5 the stage and 6 the slide id.

    Args:
        batch: List of samples as returned by the dataset's ``__getitem__``.

    Returns:
        list: ``[img, img2, label, event_time, censorship, stage, slide_id]``.
        ``event_time`` stays a NumPy array and ``slide_id`` a plain list, as
        they do not need to be tensors for training.
    """

    def column(position):
        # Gather the field stored at `position` across all samples.
        return [sample[position] for sample in batch]

    merged_bags = torch.cat(column(0), dim=0)
    flattened_bags = torch.cat(column(1), dim=0)
    labels = torch.LongTensor(column(2))
    event_times = np.array(column(3))
    censorships = torch.FloatTensor(column(4))
    stages = torch.LongTensor(column(5))
    slide_ids = column(6)

    return [
        merged_bags,
        flattened_bags,
        labels,
        event_times,
        censorships,
        stages,
        slide_ids,
    ]
|
|
32 |
|
|
|
33 |
|
|
|
34 |
class FeatureBagsDataset(Dataset):
    """Dataset returning per-slide feature bags with their survival targets.

    Each row of the slide table describes one slide. Its features are loaded
    from ``<data_dir>/<slide_id>_Mergedfeatures.pt``.
    """

    def __init__(self, df, data_dir, input_feature_size, stage_class):
        """Store a re-indexed copy of the slide table and the settings.

        Args:
            df: Table with the columns ``slide_id``, ``stage``, ``disc_label``,
                ``recurrence_years`` and ``censorship``.
            data_dir: Directory containing the feature files.
            input_feature_size: Stored on the instance; not used by this class.
            stage_class: Stored on the instance; not used by this class.
        """
        self.data_dir = data_dir
        self.input_feature_size = input_feature_size
        self.stage_class = stage_class
        # Copy so the caller's frame is untouched; the fresh 0..n-1 index
        # makes the label lookups in __getitem__ match the sample position.
        self.slide_df = df.copy().reset_index(drop=True)

    def _get_feature_path(self, slide_id):
        """Return the path of the feature file belonging to ``slide_id``."""
        filename = f"{slide_id}_Mergedfeatures.pt"
        return os.path.join(self.data_dir, filename)

    def __getitem__(self, idx):
        """Load the sample stored at row ``idx``.

        Returns:
            tuple: ``(features_merged, features_flattened, label, event_time,
            censorship, stage, slide_id)``.
        """
        table = self.slide_df
        slide_id, stage, label, event_time, censorship = (
            table[column][idx]
            for column in (
                "slide_id",
                "stage",
                "disc_label",
                "recurrence_years",
                "censorship",
            )
        )

        bags = torch.load(self._get_feature_path(slide_id))

        # NOTE(review): presumably every entry of the loaded object is a
        # sequence whose first item is a NumPy array -- confirm against the
        # code that writes the *_Mergedfeatures.pt files.

        # Merged features: one mean vector (over axis 0) per entry.
        entry_means = [entry[0].mean(0) for entry in bags]
        features_merged = torch.from_numpy(np.array(entry_means))

        # Alternative: all features of all entries, concatenated.
        entry_arrays = [entry[0] for entry in bags]
        features_flattened = torch.from_numpy(np.concatenate(entry_arrays))

        return (
            features_merged,
            features_flattened,
            label,
            event_time,
            censorship,
            stage,
            slide_id,
        )

    def __len__(self):
        """Return the number of slides in the table."""
        return len(self.slide_df)
|
|
65 |
|
|
|
66 |
def define_data_sampling(train_split, val_split, method, workers):
    """Build the training and validation DataLoaders.

    Args:
        train_split: Dataset used for training.
        val_split: Dataset used for validation.
        method: Sampling strategy for the training loader; only ``"random"``
            is implemented.
        workers: Number of DataLoader worker processes.

    Returns:
        tuple: ``(train_loader, val_loader)``.

    Raises:
        Exception: If ``method`` is anything other than ``"random"``.
    """
    # Reproducibility of DataLoader: one generator, seeded with 0, is shared
    # by both loaders.
    rng = torch.Generator()
    rng.manual_seed(0)

    if method != "random":
        raise Exception(f"Sampling method '{method}' not implemented.")

    # Options common to both loaders.
    shared_options = dict(
        batch_size=1,  # model expects one bag of features at the time.
        collate_fn=collate,
        num_workers=workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=rng,
    )

    # Training data: shuffled.
    print("random sampling setting")
    train_loader = DataLoader(dataset=train_split, shuffle=True, **shared_options)

    # Validation data: always visited in order.
    val_loader = DataLoader(
        dataset=val_split,
        sampler=SequentialSampler(val_split),
        **shared_options,
    )

    return train_loader, val_loader
|
|
99 |
|
|
|
100 |
class MonitorBestModelEarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Also keeps track of the best validation loss and the best C-index seen so
    far, together with the epochs at which they occurred.
    """

    def __init__(self, patience=15, min_epochs=20, saving_checkpoint=True):
        """
        Args:
            patience (int): How many consecutive calls with a worse validation
                            loss are tolerated before stopping.
                            Default: 15
            min_epochs (int): Stopping is only triggered when the epoch given
                            to ``__call__`` is strictly greater than this.
                            Default: 20
            saving_checkpoint (bool): If False, ``save_checkpoint`` never
                            writes a file.
                            Default: True
        """
        self.patience = patience
        self.min_epochs = min_epochs
        self.counter = 0
        self.early_stop = False

        # `np.inf`, not `np.Inf`: the capitalised alias was removed in NumPy 2.0.
        self.eval_loss_min = np.inf
        self.best_loss_score = None
        self.best_epoch_loss = None

        self.best_CI_score = 0.0
        self.best_metrics_score = None
        self.best_epoch_CI = None

        self.saving_checkpoint = saving_checkpoint

    def __call__(self, epoch, eval_loss, eval_cindex, eval_other_metrics, model, log_dir):
        """Update the monitor with the evaluation results of one epoch.

        Args:
            epoch: Index of the epoch that was just evaluated.
            eval_loss: Validation loss (lower is better).
            eval_cindex: Validation C-index (higher is better).
            eval_other_metrics: Extra metrics stored alongside the best C-index.
            model: Currently unused (see note below).
            log_dir: Currently unused (see note below).

        NOTE: checkpoint saving is not triggered from here; ``model`` and
        ``log_dir`` are only accepted so that ``save_checkpoint(model,
        log_dir, epoch)`` can be re-enabled at the three update points.
        """
        # Negated so that "higher is better", like the C-index.
        loss_score = -eval_loss
        CI_score = eval_cindex
        metrics_score = eval_other_metrics

        # First call: take the current values as the baseline.
        if self.best_loss_score is None:
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self._update_metrics_scores(CI_score, metrics_score, epoch)

        # Eval loss starts increasing. Recommend running early stopping on the loss.
        elif loss_score < self.best_loss_score:
            self.counter += 1
            print(f'Evaluation loss does not decrease : Starting Early stopping counter {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.min_epochs:
                self.early_stop = True

        # Eval loss keeps decreasing (an equal loss also lands here).
        else:
            print(f'Epoch {epoch} validation loss decreased ({self.eval_loss_min:.6f} --> {eval_loss:.6f})')
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self.counter = 0

        # We may have a tiny lag between min loss and the best C-index. With the
        # patience and early stop, it is fine but better to track the highest
        # C-index too.
        if CI_score > self.best_CI_score:
            self._update_metrics_scores(CI_score, metrics_score, epoch)

    def save_checkpoint(self, model, log_dir, epoch):
        """Save ``model.state_dict()`` to ``<log_dir>/<epoch>_checkpoint.pt``.

        Nothing is written when checkpoint saving is disabled or when the
        file already exists.
        """
        filepath = os.path.join(log_dir, f"{epoch}_checkpoint.pt")
        if self.saving_checkpoint and not os.path.exists(filepath):
            print("Saving model")
            torch.save(model.state_dict(), filepath)

    def _update_loss_scores(self, loss_score, eval_loss, epoch):
        """Record a new best validation loss and the epoch it occurred at."""
        self.eval_loss_min = eval_loss
        self.best_loss_score = loss_score
        self.best_epoch_loss = epoch
        print(f'Updating loss at epoch {self.best_epoch_loss} -> {self.eval_loss_min:.6f}')

    def _update_metrics_scores(self, CI_score, metrics_score, epoch):
        """Record a new best C-index, its epoch and the accompanying metrics."""
        # Even though loss would decrease, C-index does not necessarily increase. Keep track also on best C-index.
        self.best_CI_score = CI_score
        self.best_epoch_CI = epoch
        self.best_metrics_score = metrics_score
        print(f'Updating C-index at epoch {self.best_epoch_CI} -> {self.best_CI_score:.6f}')
|
|
175 |
|
|
|
176 |
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    learning_rates = (group["lr"] for group in optimizer.param_groups)
    # Only the first group is ever consulted.
    return next(learning_rates, None)
|
|
179 |
|
|
|
180 |
def print_model(model):
    """Print the model, then the number of its trainable parameters.

    A parameter counts as trainable when its ``requires_grad`` is truthy.
    """
    print(model)
    trainable_total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable_total += parameter.numel()
    print(f"Model has {trainable_total} parameters")
|
|
184 |
|
|
|
185 |
def get_survival_data_for_BS(df, time_col_name, censorship_col_name='censorship'):
    """Convert survival columns into the structured array used for the Brier score.

    To compute one survival metric, the Brier score (BS), you need a specific
    format of censorship and times, from which the censoring distribution is
    estimated: a structured array containing the binary event indicator as
    first field (1 event occurred; 0 censored) and the time of event or time
    of censoring as second field.

    Args:
        df: Table holding the time and censorship columns.
        time_col_name: Name of the column with the event/censoring times.
        censorship_col_name: Name of the column with the censorship flag
            (1 censored, 0 event observed).

    Returns:
        tuple: ``(max_time, y)`` where ``max_time`` is the maximum of the time
        column and ``y`` is a structured array with the fields ``cens``
        (bool) and ``time`` (float64), one entry per row of ``df``.
    """
    max_time = df[time_col_name].max()

    y = np.empty(len(df), dtype=[('cens', '?'), ('time', '<f8')])
    # Note that we take the *uncensorship* status: the field is True when the
    # event was observed, i.e. when `1 - censorship` is non-zero.
    y['cens'] = (1 - df[censorship_col_name].to_numpy()).astype(bool)
    y['time'] = df[time_col_name].to_numpy()
    return max_time, y
|
|
197 |
|
|
|
198 |
def get_bins_time_value(df, n_bins, time_col_name, label_time_col_name='disc_label', censorship_col_name='censorship'):
    """Return the upper time edge of each of the ``n_bins`` quantile bins.

    The quantile edges are computed on the uncensored cases only
    (``censorship == 0``). The last edge is replaced by infinity.

    Args:
        df: Table holding the time and censorship columns.
        n_bins: Number of quantile bins.
        time_col_name: Name of the column with the times.
        label_time_col_name: Accepted but not used by this function.
        censorship_col_name: Name of the censorship column.

    Returns:
        numpy.ndarray: The ``n_bins`` upper edges; the last one is ``inf``.
    """
    is_uncensored = df[censorship_col_name] == 0
    event_times = df.loc[is_uncensored, time_col_name]

    _, bin_edges = pd.qcut(event_times, q=n_bins, retbins=True, labels=False)
    bin_edges[0] = 0
    bin_edges[-1] = float('inf')

    # There are n_bins + 1 edges; the leading one (set to 0) is not returned.
    return bin_edges[1:]
|
|
208 |
|
|
|
209 |
def _bin_index_for_times(bins_values, times):
    # For each time, index of the first bin whose upper edge is strictly greater.
    edges = np.asarray(bins_values)
    return np.asarray([np.argwhere(t < edges)[0, 0] for t in times])


def compute_surv_metrics_eval(
    bins_values,
    all_survival_probs,
    all_risk_scores,
    train_BS,
    test_BS,
    years_of_interest=None,
    yearI_of_interest=1.0,
    yearF_of_interest=5.0,
    time_step=0.5):
    """Compute Brier score, integrated Brier score, cumulative/dynamic AUC and IPCW C-index.

    Note that we discretized the continuous time scale into N bins for
    training, thus having only N survival probabilities at each n time step.
    The Brier score does not return the exact same value if we use the
    discretized time or the continuous one, because it uses the distribution
    of censoring, even within each interval with same probability of survival.
    BS, IBS and AUC are time-dependent and step-dependent scores.

    Args:
        bins_values: Upper time edge of each bin; every requested time must be
            strictly below at least one edge.
        all_survival_probs: Survival probabilities, indexed as
            ``[:, bin_index]`` (presumably one row per sample -- confirm
            against the caller).
        all_risk_scores: Risk scores handed to the AUC and the C-index.
        train_BS: Structured survival array of the training set.
        test_BS: Structured survival array of the evaluated set.
        years_of_interest: Times at which the Brier score is evaluated.
            Defaults to ``[1.0, 2.0, 3.0, 5.0]``; this is a free choice,
            depending on what makes sense.
        yearI_of_interest: First time of the grid used for IBS and AUC.
        yearF_of_interest: Last time of that grid (included).
        time_step: Step of that grid (0.5 == 6 months).

    Returns:
        tuple: ``((BS, years_of_interest), (IBS, yearI_of_interest,
        yearF_of_interest), (cumAUC, meanAUC), c_index_ipwc)``. The last
        element is the bare C-index value, not a 1-tuple.
    """
    # Build the default per call: the list is handed back to the caller, so a
    # shared mutable default could be modified from outside.
    if years_of_interest is None:
        years_of_interest = [1.0, 2.0, 3.0, 5.0]

    corresponding_bins = _bin_index_for_times(bins_values, years_of_interest)

    # Note that this requires that survival times survival_test lie within the range of survival times survival_train.
    # This can be achieved by specifying times accordingly, e.g. by setting times[-1] slightly below the maximum expected follow-up time.
    _, BS = brier_score(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_survival_probs[:, corresponding_bins],
        times=years_of_interest)

    # The Integrated Brier Score (IBS) provides an overall calculation of the model performance at all available times.
    # Both time points (at least two) and time step is a pure choice. Note that the time step has an impact on the value of IBS.
    # The stop value is pushed one step further, otherwise the final year is not included.
    # NOTE: with a non-integer step np.arange may yield a grid whose last
    # point lies beyond yearF_of_interest when the span is not a multiple of
    # time_step.
    years_of_interest_steps = np.arange(yearI_of_interest, yearF_of_interest + time_step, step=time_step, dtype=float)
    corresponding_bins = _bin_index_for_times(bins_values, years_of_interest_steps)

    IBS = integrated_brier_score(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_survival_probs[:, corresponding_bins],
        times=years_of_interest_steps)

    # The AUC can be extended to survival data by defining sensitivity (true positive rate) and specificity (true negative rate) as time-dependent measures.
    # Cumulative cases are all individuals that experienced an event prior to or at time t (ti≤t), whereas dynamic controls are those with ti>t.
    # The associated cumulative/dynamic AUC quantifies how well a model can distinguish subjects who fail by a given time (ti≤t) from subjects who fail after this time (ti>t).
    cumAUC, meanAUC = cumulative_dynamic_auc(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_risk_scores,  # the AUC uses the risk scores like the C-index and not the event-free probabilities.
        times=years_of_interest_steps, tied_tol=1e-08)

    c_index_ipwc = concordance_index_ipcw(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_risk_scores,
        tied_tol=1e-08)[0]

    return (BS, years_of_interest), (IBS, yearI_of_interest, yearF_of_interest), (cumAUC, meanAUC), (c_index_ipwc)