a b/utils/eval_utils_survival.py
1
import matplotlib.pyplot as plt
2
import numpy as np
3
import os
4
import pandas as pd
5
import torch
6
7
from datasets.dataset_survival import Generic_MIL_Survival_Dataset
8
from lifelines.utils import concordance_index
9
from models.model_amil import AMIL
10
from models.model_mil import MIL_fc_Surv
11
from pycox.evaluation import EvalSurv
12
from utils.utils import *
13
14
15
def initiate_model(settings, ckpt_path):
    """Build a survival MIL model and load its weights from a checkpoint.

    Args:
        settings: Configuration mapping. Keys read here: ``'drop_out'``,
            ``'model_size'``, ``'model_type'`` (``'amil'`` or ``'mil'``) and
            ``'print_model_info'``.
        ckpt_path: Path of a checkpoint file holding the model state dict.

    Returns:
        The model with checkpoint weights loaded, after ``relocate()`` has been
        called on it and it has been switched to evaluation mode.

    Raises:
        NotImplementedError: If ``settings['model_type']`` is not supported.
    """
    print('Initialize model ...', end=' ')
    model_type = settings['model_type']
    model_kwargs = {"dropout": settings['drop_out']}

    # Only the attention model takes a size argument.
    if settings['model_size'] is not None and model_type == 'amil':
        model_kwargs["size_arg"] = settings['model_size']

    if model_type == 'amil':
        model = AMIL(**model_kwargs)
    elif model_type == 'mil':
        model = MIL_fc_Surv(**model_kwargs)
    else:
        raise NotImplementedError(
            'Model type {!r} not supported, must be "amil" or "mil"'.format(model_type))

    # map_location='cpu' lets checkpoints saved from a GPU be loaded on machines
    # without CUDA; load_state_dict copies the tensors into the model's own
    # parameters, and device placement is left to relocate() below.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # Skip 'instance_loss_fn' entries and strip '.module' from key names
    # (presumably left over from wrapped sub-modules at training time).
    ckpt_clean = {
        key.replace('.module', ''): value
        for key, value in ckpt.items()
        if 'instance_loss_fn' not in key
    }
    model.load_state_dict(ckpt_clean, strict=True)

    model.relocate()
    model.eval()
    print('Done.')

    if settings['print_model_info']:
        print_network(model)

    return model
45
46
47
48
class _BaseEvaluationData:
49
    event_col = 'event'
50
    time_col = 'time'
51
    
52
    def __init__(self, settings):
53
        print('Initialize data ...', end=' ')
54
        self.dataset = Generic_MIL_Survival_Dataset(csv_path = settings['csv_path'],
55
            data_dir= os.path.join(settings['data_root_dir'], settings['feature_dir']),
56
            shuffle = False, 
57
            print_info = settings['print_data_info'],
58
            label_dict = {'lebt':0, 'tod':1},
59
            event_col = self.event_col,
60
            time_col = self.time_col,
61
            patient_strat=True,
62
            ignore=[])
63
64
        self.split_path = '{}/splits_{}.csv'.format(settings['split_dir'], settings['split_idx'])
65
        print('Done.')
66
67
    def _get_split_data(self, split):
68
        assert split in ['train', 'val', 'test', 'all'], 'Split {} not recognized, must be in [train, val, test, all]'.format(split)
69
        train, val, test = self.dataset.return_splits(from_id=False, csv_path=self.split_path)
70
        if split == 'train':
71
            loader = get_simple_loader(train, survival=True)
72
        elif split == 'val':
73
            loader = get_simple_loader(val, survival=True)
74
        elif split == 'test':
75
            loader = get_simple_loader(test, survival=True)
76
        elif split == 'all':
77
            loader = get_simple_loader(self.dataset, survival=True)
78
        return loader, loader.dataset.slide_data
79
80
81
82
class _BaseEvaluationAMIL(_BaseEvaluationData):
    """Prediction and metric helpers for a Cox-style survival MIL model.

    The model's first output is treated as a log-risk score. Survival curves
    are computed as ``S(t | x) = exp(-H0(t) * exp(risk(x)))``, where the baseline
    cumulative hazard ``H0`` is a Breslow estimate fitted on the training split.
    """

    def __init__(self, settings):
        """Load the data (via the base class) and the model checkpoint.

        Args:
            settings: Configuration mapping. In addition to the keys read by
                ``_BaseEvaluationData``, ``'models_dir'`` and ``'split_idx'``
                are read here, and ``settings`` is passed to ``initiate_model``.
        """
        super().__init__(settings)

        # Checkpoint of the configured split: <models_dir>/s_<split_idx>_checkpoint.pt
        ckpt_path = os.path.join(settings['models_dir'],
                                 's_{}_checkpoint.pt'.format(settings['split_idx']))
        self.model = initiate_model(settings, ckpt_path)

        # Computed lazily from the training split. The misspelled attribute name
        # "cumulutative" is kept because other code refers to it.
        self.baseline_hazard = None
        self.baseline_cumulutative_hazard = None

        # Cached results (filled in by EvaluationAMIL.compute_metrics).
        self.patient_predictions = None
        self.c_index = None
        self.c_index_td = None
        self.ibs = None
        self.inbll = None

    def _compute_risk(self, loader):
        """Run the model over ``loader`` and collect times, events and risk scores.

        Every element yielded by ``loader`` is unpacked as ``(data, event, time)``
        and ``.item()`` is called on the risk, event and time, so each batch must
        hold a single sample.

        Returns:
            Tuple of numpy arrays ``(times, events, risks)`` in loader order.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        risks, events, times = [], [], []
        with torch.no_grad():
            for data, event, survival_time in loader:
                risk, _, _ = self.model(data.to(device))
                risks.append(risk.item())
                events.append(event.item())
                times.append(survival_time.item())
        return np.asarray(times), np.asarray(events), np.asarray(risks)

    def _compute_baseline_harzards(self):
        """Compute Breslow estimates of the baseline hazard from the training data.

        Returns:
            pandas Series named ``'baseline_hazards'``, indexed by the unique
            training survival times in ascending order.

        Modified from https://github.com/havakv/pycox/blob/0e9d6f9a1eff88a355ead11f0aa68bfb94647bf8/pycox/models/cox.py#L63
        """
        loader, dataset = self._get_split_data('train')
        _, _, risk_scores = self._compute_risk(loader)
        # NOTE(review): assumes the loader yields exactly one sample per row of
        # `dataset`, in row order, so the risk scores align with the rows.
        return (dataset
                .assign(exp_risk=np.exp(risk_scores))
                .groupby(self.time_col)
                .agg({'exp_risk': 'sum', self.event_col: 'sum'})
                .sort_index(ascending=False)
                # Reverse cumulative sum = sum of exp(risk) over the risk set at t.
                .assign(exp_risk=lambda x: x['exp_risk'].cumsum())
                .pipe(lambda x: x[self.event_col] / x['exp_risk'])
                # A 0/0 (e.g. from underflowed exp(risk)) would yield NaN; pycox maps it to 0.
                .fillna(0.)
                .iloc[::-1]
                .rename('baseline_hazards'))

    def _compute_baseline_cumulative_hazards(self):
        """Compute and store the baseline hazard and its cumulative sum."""
        print('Estimate baseline cumulative hazard ...', end=' ')
        base_hazard = self._compute_baseline_harzards()
        self.baseline_hazard = base_hazard
        self.baseline_cumulutative_hazard = base_hazard.cumsum().rename('baseline_cumulative_hazards')
        print('Done.')

    def _predict_survival_function(self, loader):
        """Predict survival curves for the samples in ``loader``.

        Returns:
            Tuple ``(times, events, surv)`` where ``surv`` is a float64 tensor of
            shape ``(n_time_points, n_samples)``; rows follow the index of the
            baseline cumulative hazard.
        """
        if self.baseline_cumulutative_hazard is None:
            self._compute_baseline_cumulative_hazards()

        base_ch = self.baseline_cumulutative_hazard.values.reshape(-1, 1).astype(float)
        times, events, risks = self._compute_risk(loader)
        exp_risk = np.exp(risks).reshape(1, -1)
        # Outer product H0(t) * exp(risk): (T, 1) @ (1, N) -> (T, N).
        surv = np.exp(-(base_ch @ exp_risk))
        return times, events, torch.from_numpy(surv)

    def _predict_risk(self, loader):
        """Return ``(times, events, risks)`` for the samples in ``loader``."""
        return self._compute_risk(loader)

    def _collect_patient_ids(self, split):
        """Return the index of the slide table of ``split``."""
        _, slide_data = self._get_split_data(split)
        return slide_data.index

    def _unpack_data(self, data):
        """Split a ``patient -> record`` mapping into parallel lists.

        Returns:
            Lists ``(times, events, predictions)`` taken from each record's
            ``'time'``, ``'event'`` and ``'probabilities'`` entries.
        """
        records = list(data.values())
        times = [record['time'] for record in records]
        events = [record['event'] for record in records]
        predictions = [record['probabilities'] for record in records]
        return times, events, predictions

    def _compute_c_index(self, data):
        """Compute the lifelines concordance index at every survival time point.

        Survival probabilities are higher for patients expected to live longer,
        which matches the score convention of ``lifelines.utils.concordance_index``.

        Returns:
            List with one C-index per time point (per row of the survival curves).
        """
        times, events, predictions = self._unpack_data(data)
        # Stack the per-patient curves to (n_patients, n_times), then transpose
        # so that iteration runs over time points.
        probs_by_time = torch.stack(predictions).permute(1, 0)
        return [concordance_index(event_times=times,
                                  predicted_scores=time_probs,
                                  event_observed=events)
                for time_probs in probs_by_time]

    def _predictions_to_pycox(self, data, time_points=None):
        """Arrange survival curves in the DataFrame layout used by pycox ``EvalSurv``.

        ``EvalSurv`` expects one column per individual and reads the evaluation
        times from the row index. The curves are evaluated at the time points of
        the baseline cumulative hazard, so that index is used as row index.
        Previously a default 0..T-1 index was used, which silently evaluated all
        time-dependent metrics at wrong times.

        Args:
            data: Mapping ``patient -> {'probabilities': curve, ...}``.
            time_points: Unused; kept for interface compatibility.

        Returns:
            DataFrame with one column per patient.
        """
        predictions = {k: np.asarray(v['probabilities']) for k, v in data.items()}
        df = pd.DataFrame.from_dict(predictions)
        if predictions and self.baseline_cumulutative_hazard is not None:
            df.index = self.baseline_cumulutative_hazard.index
        return df
182
183
184
185
class EvaluationAMIL(_BaseEvaluationAMIL):
    """Evaluate a trained survival MIL model on one data split at a time.

    Predictions and metrics are cached per split; requesting a different split
    clears the cache so results are never reused across splits.
    """

    def __init__(self, settings):
        super().__init__(settings)
        # Split the cached predictions/metrics currently belong to.
        self.split = None

    def _check_split_data(self, split):
        """Clear cached results if ``split`` differs from the cached split.

        ``self.split`` is updated to ``split``. Previously it was never updated
        after the first call, so switching back to the first split returned
        metrics cached for another split.
        """
        if self.split == split:
            return
        self.patient_predictions = None
        self.c_index = None
        self.c_index_td = None
        self.ibs = None
        self.inbll = None
        self.split = split

    def _collect_patient_predictions(self, split):
        """Predict survival curves for every patient in ``split``.

        Returns:
            Dict ``patient_id -> {'time', 'event', 'probabilities'}``, where
            ``probabilities`` is the patient's survival curve (1-D tensor).
        """
        loader, slide_data = self._get_split_data(split)
        times, events, surv = self._predict_survival_function(loader)
        # NOTE(review): assumes the loader yields samples in slide_data row order.
        return {
            patient: {'time': times[i],
                      'event': events[i],
                      'probabilities': surv[:, i]}
            for i, patient in enumerate(slide_data.index)
        }

    def _compute_pycox_metrics(self, data, time_points=None,
                               drop_last_times=0):
        """Compute time-dependent C-index, integrated Brier score and integrated NBLL.

        Args:
            data: Patient predictions as returned by ``_collect_patient_predictions``.
            time_points: Forwarded to ``_predictions_to_pycox``.
            drop_last_times: Number of points removed from the end of the
                100-point evaluation grid, since the scores can become unstable
                at the largest times.

        Returns:
            Tuple ``(c_index_td, ibs, inbll)``.
        """
        times, events, _ = self._unpack_data(data)
        times, events = np.array(times), np.array(events)
        predictions = self._predictions_to_pycox(data, time_points)

        ev = EvalSurv(predictions, times, events, censor_surv='km')
        # Using "antolini" method instead of "adj_antolini" resulted in Ctd
        # values different from C-index for proportional hazard methods (for
        # CNV data); this is presumably due to the tie handling, since that is
        # what the pycox authors "adjust" (see code comments at:
        # https://github.com/havakv/pycox/blob/6ed3973954789f54453055bbeb85887ded2fb81c/pycox/evaluation/eval_surv.py#L171)
        c_index_td = ev.concordance_td('adj_antolini')

        # 100-point time grid spanning the observed times of this split.
        time_grid = np.linspace(times.min(), times.max(), 100)
        if drop_last_times > 0:
            time_grid = time_grid[:-drop_last_times]
        ibs = ev.integrated_brier_score(time_grid)
        inbll = ev.integrated_nbll(time_grid)

        return c_index_td, ibs, inbll

    def compute_metrics(self, split, time_points=None):
        """Calculate evaluation metrics for ``split`` and cache them on the instance."""
        self._check_split_data(split)

        print('Compute evaluation metrics ... \n', end =' ')
        if self.patient_predictions is None:
            # Get all patient labels and predictions
            self.patient_predictions = self._collect_patient_predictions(split)

        if self.c_index is None:
            self.c_index = self._compute_c_index(self.patient_predictions)

        if self.c_index_td is None:
            td_metrics = self._compute_pycox_metrics(self.patient_predictions,
                                                     time_points)
            self.c_index_td, self.ibs, self.inbll = td_metrics
        print('Done.')

    def predict_risk(self, split):
        """Return ``(times, events, risks)`` for all samples of ``split``."""
        loader, _ = self._get_split_data(split)
        return self._predict_risk(loader)

    def return_results(self):
        """Return the cached ``(c_index, c_index_td, ibs, inbll)`` tuple.

        Raises:
            RuntimeError: If any metric has not been computed yet. Checking
                ``is None`` (rather than truthiness) allows legitimate 0.0 values.
        """
        results = (self.c_index, self.c_index_td, self.ibs, self.inbll)
        if any(value is None for value in results):
            raise RuntimeError('Results not available.' +
                               ' Please call "compute_metrics" or "run_bootstrap" first.')
        return results
280
281
282
    
283
284