# code/utils/utils.py
import os
import sys
import re
import glob
import pickle
import copy

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import wfdb
import ast
from sklearn.metrics import fbeta_score, roc_auc_score, roc_curve, auc
from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
from matplotlib.axes._axes import _log as matplotlib_axes_logger
import warnings

# EVALUATION STUFF
def generate_results(idxs, y_true, y_pred, thresholds):
    return evaluate_experiment(y_true[idxs], y_pred[idxs], thresholds)

def evaluate_experiment(y_true, y_pred, thresholds=None):
    results = {}

    if thresholds is not None:
        # binary predictions
        y_pred_binary = apply_thresholds(y_pred, thresholds)
        # PhysioNet/CinC Challenges metrics
        challenge_scores = challenge_metrics(y_true, y_pred_binary, beta1=2, beta2=2)
        results['F_beta_macro'] = challenge_scores['F_beta_macro']
        results['G_beta_macro'] = challenge_scores['G_beta_macro']

    # label based metric
    results['macro_auc'] = roc_auc_score(y_true, y_pred, average='macro')

    df_result = pd.DataFrame(results, index=[0])
    return df_result
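
# Illustrative usage sketch (not part of the original pipeline; the shapes and values
# below are assumptions for demonstration): evaluate_experiment expects y_true as a
# binary (n_samples, n_classes) matrix and y_pred as scores of the same shape.
def _example_evaluate_experiment():
    rng = np.random.RandomState(0)
    y_true = rng.randint(0, 2, size=(100, 5))
    y_true[y_true.sum(axis=1) == 0, 0] = 1      # every sample needs at least one label
    y_pred = rng.rand(100, 5)                   # predicted scores in [0, 1]
    thresholds = np.full(5, 0.5)                # one decision threshold per class
    return evaluate_experiment(y_true, y_pred, thresholds)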

def challenge_metrics(y_true, y_pred, beta1=2, beta2=2, class_weights=None, single=False):
    f_beta = 0
    g_beta = 0
    if single: # if evaluating single class in case of threshold-optimization
        sample_weights = np.ones(y_true.sum(axis=1).shape)
    else:
        sample_weights = y_true.sum(axis=1)
    for classi in range(y_true.shape[1]):
        y_truei, y_predi = y_true[:,classi], y_pred[:,classi]
        TP, FP, TN, FN = 0., 0., 0., 0.
        for i in range(len(y_predi)):
            sample_weight = sample_weights[i]
            if y_truei[i]==y_predi[i]==1:
                TP += 1./sample_weight
            if ((y_predi[i]==1) and (y_truei[i]!=y_predi[i])):
                FP += 1./sample_weight
            if y_truei[i]==y_predi[i]==0:
                TN += 1./sample_weight
            if ((y_predi[i]==0) and (y_truei[i]!=y_predi[i])):
                FN += 1./sample_weight
        f_beta_i = ((1+beta1**2)*TP)/((1+beta1**2)*TP + FP + (beta1**2)*FN)
        g_beta_i = (TP)/(TP+FP+beta2*FN)

        f_beta += f_beta_i
        g_beta += g_beta_i

    return {'F_beta_macro':f_beta/y_true.shape[1], 'G_beta_macro':g_beta/y_true.shape[1]}
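
# Per class i, with the sample-weighted counts TP, FP, TN, FN accumulated above,
# challenge_metrics computes
#   F_beta_i = (1 + beta1^2) * TP / ((1 + beta1^2) * TP + FP + beta1^2 * FN)
#   G_beta_i = TP / (TP + FP + beta2 * FN)
# and returns the unweighted means over all classes as 'F_beta_macro' and 'G_beta_macro'.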

def get_appropriate_bootstrap_samples(y_true, n_bootstraping_samples):
    samples = []
    while True:
        ridxs = np.random.randint(0, len(y_true), len(y_true))
        if y_true[ridxs].sum(axis=0).min() != 0:
            samples.append(ridxs)
            if len(samples) == n_bootstraping_samples:
                break
    return samples
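
# Illustrative sketch (an assumed usage pattern, not taken from the original file):
# bootstrapping test-set indices to get an empirical confidence interval around the
# point estimate. The 5%/95% quantiles below are an assumption for illustration.
def _example_bootstrap_confidence(y_true, y_pred, thresholds=None, n_bootstraping_samples=100):
    point = evaluate_experiment(y_true, y_pred, thresholds)
    samples = get_appropriate_bootstrap_samples(y_true, n_bootstraping_samples)
    boot = pd.concat([generate_results(idxs, y_true, y_pred, thresholds) for idxs in samples])
    return point, boot.quantile(0.05), boot.quantile(0.95)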

def find_optimal_cutoff_threshold(target, predicted):
    """
    Find the probability cutoff that maximizes Youden's J statistic (TPR - FPR)
    on the ROC curve of a binary classifier.
    """
    fpr, tpr, threshold = roc_curve(target, predicted)
    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = threshold[optimal_idx]
    return optimal_threshold

def find_optimal_cutoff_thresholds(y_true, y_pred):
    return [find_optimal_cutoff_threshold(y_true[:,i], y_pred[:,i]) for i in range(y_true.shape[1])]

def find_optimal_cutoff_threshold_for_Gbeta(target, predicted, n_thresholds=100):
    thresholds = np.linspace(0.00, 1, n_thresholds)
    scores = [challenge_metrics(target, predicted > t, single=True)['G_beta_macro'] for t in thresholds]
    optimal_idx = np.argmax(scores)
    return thresholds[optimal_idx]

def find_optimal_cutoff_thresholds_for_Gbeta(y_true, y_pred):
    print("optimize thresholds with respect to G_beta")
    return [find_optimal_cutoff_threshold_for_Gbeta(y_true[:,k][:,np.newaxis], y_pred[:,k][:,np.newaxis]) for k in tqdm(range(y_true.shape[1]))]

def apply_thresholds(preds, thresholds):
    """
    Apply class-wise thresholds to prediction scores in order to get binary predictions.
    BUT: if no score is above its threshold, pick the maximum. This is needed due to metric issues.
    """
    tmp = []
    for p in preds:
        tmp_p = (p > thresholds).astype(int)
        if np.sum(tmp_p) == 0:
            tmp_p[np.argmax(p)] = 1
        tmp.append(tmp_p)
    tmp = np.array(tmp)
    return tmp
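
# Illustrative sketch (an assumed usage pattern, not taken from the original file):
# per-class thresholds would typically be tuned on a validation split and then applied
# to the test-set scores before computing the challenge metrics.
def _example_threshold_pipeline(y_val, y_val_pred, y_test, y_test_pred):
    thresholds = find_optimal_cutoff_thresholds_for_Gbeta(y_val, y_val_pred)
    y_test_binary = apply_thresholds(y_test_pred, thresholds)
    return challenge_metrics(y_test, y_test_binary, beta1=2, beta2=2)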

# DATA PROCESSING STUFF

def load_dataset(path, sampling_rate, release=False):
    if path.split('/')[-2] == 'ptbxl':
        # load and convert annotation data
        Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')
        Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

        # Load raw signal data
        X = load_raw_data_ptbxl(Y, sampling_rate, path)

    elif path.split('/')[-2] == 'ICBEB':
        # load and convert annotation data
        Y = pd.read_csv(path+'icbeb_database.csv', index_col='ecg_id')
        Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

        # Load raw signal data
        X = load_raw_data_icbeb(Y, sampling_rate, path)

    return X, Y
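
# Illustrative sketch (the data folder is an assumption): loading PTB-XL at 100 Hz.
# X is the array of raw signals and Y the annotation DataFrame indexed by ecg_id;
# the raw-signal loaders below cache the signals as raw100.npy / raw500.npy on first use.
def _example_load_ptbxl(data_folder='../data/ptbxl/'):
    X, Y = load_dataset(data_folder, sampling_rate=100)
    return X, Y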

def load_raw_data_icbeb(df, sampling_rate, path):

    if sampling_rate == 100:
        if os.path.exists(path + 'raw100.npy'):
            data = np.load(path+'raw100.npy', allow_pickle=True)
        else:
            data = [wfdb.rdsamp(path + 'records100/'+str(f)) for f in tqdm(df.index)]
            data = np.array([signal for signal, meta in data])
            with open(path+'raw100.npy', 'wb') as outfile:
                pickle.dump(data, outfile, protocol=4)
    elif sampling_rate == 500:
        if os.path.exists(path + 'raw500.npy'):
            data = np.load(path+'raw500.npy', allow_pickle=True)
        else:
            data = [wfdb.rdsamp(path + 'records500/'+str(f)) for f in tqdm(df.index)]
            data = np.array([signal for signal, meta in data])
            with open(path+'raw500.npy', 'wb') as outfile:
                pickle.dump(data, outfile, protocol=4)
    return data

def load_raw_data_ptbxl(df, sampling_rate, path):
    if sampling_rate == 100:
        if os.path.exists(path + 'raw100.npy'):
            data = np.load(path+'raw100.npy', allow_pickle=True)
        else:
            data = [wfdb.rdsamp(path+f) for f in tqdm(df.filename_lr)]
            data = np.array([signal for signal, meta in data])
            with open(path+'raw100.npy', 'wb') as outfile:
                pickle.dump(data, outfile, protocol=4)
    elif sampling_rate == 500:
        if os.path.exists(path + 'raw500.npy'):
            data = np.load(path+'raw500.npy', allow_pickle=True)
        else:
            data = [wfdb.rdsamp(path+f) for f in tqdm(df.filename_hr)]
            data = np.array([signal for signal, meta in data])
            with open(path+'raw500.npy', 'wb') as outfile:
                pickle.dump(data, outfile, protocol=4)
    return data
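
# compute_label_aggregations maps the raw scp_codes dictionaries onto the label set of
# the requested task (ctype): 'diagnostic', 'subdiagnostic' and 'superdiagnostic' use the
# diagnostic hierarchy from scp_statements.csv, 'form' and 'rhythm' keep the matching
# statement codes, and 'all' keeps every scp code.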

def compute_label_aggregations(df, folder, ctype):

    df['scp_codes_len'] = df.scp_codes.apply(lambda x: len(x))

    aggregation_df = pd.read_csv(folder+'scp_statements.csv', index_col=0)

    if ctype in ['diagnostic', 'subdiagnostic', 'superdiagnostic']:

        def aggregate_all_diagnostic(y_dic):
            tmp = []
            for key in y_dic.keys():
                if key in diag_agg_df.index:
                    tmp.append(key)
            return list(set(tmp))

        def aggregate_subdiagnostic(y_dic):
            tmp = []
            for key in y_dic.keys():
                if key in diag_agg_df.index:
                    c = diag_agg_df.loc[key].diagnostic_subclass
                    if str(c) != 'nan':
                        tmp.append(c)
            return list(set(tmp))

        def aggregate_diagnostic(y_dic):
            tmp = []
            for key in y_dic.keys():
                if key in diag_agg_df.index:
                    c = diag_agg_df.loc[key].diagnostic_class
                    if str(c) != 'nan':
                        tmp.append(c)
            return list(set(tmp))

        diag_agg_df = aggregation_df[aggregation_df.diagnostic == 1.0]
        if ctype == 'diagnostic':
            df['diagnostic'] = df.scp_codes.apply(aggregate_all_diagnostic)
            df['diagnostic_len'] = df.diagnostic.apply(lambda x: len(x))
        elif ctype == 'subdiagnostic':
            df['subdiagnostic'] = df.scp_codes.apply(aggregate_subdiagnostic)
            df['subdiagnostic_len'] = df.subdiagnostic.apply(lambda x: len(x))
        elif ctype == 'superdiagnostic':
            df['superdiagnostic'] = df.scp_codes.apply(aggregate_diagnostic)
            df['superdiagnostic_len'] = df.superdiagnostic.apply(lambda x: len(x))
    elif ctype == 'form':
        form_agg_df = aggregation_df[aggregation_df.form == 1.0]

        def aggregate_form(y_dic):
            tmp = []
            for key in y_dic.keys():
                if key in form_agg_df.index:
                    c = key
                    if str(c) != 'nan':
                        tmp.append(c)
            return list(set(tmp))

        df['form'] = df.scp_codes.apply(aggregate_form)
        df['form_len'] = df.form.apply(lambda x: len(x))
    elif ctype == 'rhythm':
        rhythm_agg_df = aggregation_df[aggregation_df.rhythm == 1.0]

        def aggregate_rhythm(y_dic):
            tmp = []
            for key in y_dic.keys():
                if key in rhythm_agg_df.index:
                    c = key
                    if str(c) != 'nan':
                        tmp.append(c)
            return list(set(tmp))

        df['rhythm'] = df.scp_codes.apply(aggregate_rhythm)
        df['rhythm_len'] = df.rhythm.apply(lambda x: len(x))
    elif ctype == 'all':
        df['all_scp'] = df.scp_codes.apply(lambda x: list(set(x.keys())))

    return df

def select_data(XX, YY, ctype, min_samples, outputfolder):
    # convert multilabel to multi-hot
    mlb = MultiLabelBinarizer()

    if ctype == 'diagnostic':
        X = XX[YY.diagnostic_len > 0]
        Y = YY[YY.diagnostic_len > 0]
        mlb.fit(Y.diagnostic.values)
        y = mlb.transform(Y.diagnostic.values)
    elif ctype == 'subdiagnostic':
        # filter out labels with too few samples
        counts = pd.Series(np.concatenate(YY.subdiagnostic.values)).value_counts()
        counts = counts[counts > min_samples]
        YY.subdiagnostic = YY.subdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
        YY['subdiagnostic_len'] = YY.subdiagnostic.apply(lambda x: len(x))
        X = XX[YY.subdiagnostic_len > 0]
        Y = YY[YY.subdiagnostic_len > 0]
        mlb.fit(Y.subdiagnostic.values)
        y = mlb.transform(Y.subdiagnostic.values)
    elif ctype == 'superdiagnostic':
        counts = pd.Series(np.concatenate(YY.superdiagnostic.values)).value_counts()
        counts = counts[counts > min_samples]
        YY.superdiagnostic = YY.superdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
        YY['superdiagnostic_len'] = YY.superdiagnostic.apply(lambda x: len(x))
        X = XX[YY.superdiagnostic_len > 0]
        Y = YY[YY.superdiagnostic_len > 0]
        mlb.fit(Y.superdiagnostic.values)
        y = mlb.transform(Y.superdiagnostic.values)
    elif ctype == 'form':
        # filter
        counts = pd.Series(np.concatenate(YY.form.values)).value_counts()
        counts = counts[counts > min_samples]
        YY.form = YY.form.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
        YY['form_len'] = YY.form.apply(lambda x: len(x))
        # select
        X = XX[YY.form_len > 0]
        Y = YY[YY.form_len > 0]
        mlb.fit(Y.form.values)
        y = mlb.transform(Y.form.values)
    elif ctype == 'rhythm':
        # filter
        counts = pd.Series(np.concatenate(YY.rhythm.values)).value_counts()
        counts = counts[counts > min_samples]
        YY.rhythm = YY.rhythm.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
        YY['rhythm_len'] = YY.rhythm.apply(lambda x: len(x))
        # select
        X = XX[YY.rhythm_len > 0]
        Y = YY[YY.rhythm_len > 0]
        mlb.fit(Y.rhythm.values)
        y = mlb.transform(Y.rhythm.values)
    elif ctype == 'all':
        # filter
        counts = pd.Series(np.concatenate(YY.all_scp.values)).value_counts()
        counts = counts[counts > min_samples]
        YY.all_scp = YY.all_scp.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
        YY['all_scp_len'] = YY.all_scp.apply(lambda x: len(x))
        # select
        X = XX[YY.all_scp_len > 0]
        Y = YY[YY.all_scp_len > 0]
        mlb.fit(Y.all_scp.values)
        y = mlb.transform(Y.all_scp.values)
    else:
        # unknown ctype: X, Y and y stay undefined and the return below will fail
        pass

    # save LabelBinarizer
    with open(outputfolder+'mlb.pkl', 'wb') as tokenizer:
        pickle.dump(mlb, tokenizer)

    return X, Y, y, mlb
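
# Illustrative sketch (folders and min_samples are assumptions): one plausible way to
# chain the data-processing helpers above into a superdiagnostic classification task.
def _example_superdiagnostic_task(data_folder='../data/ptbxl/', outputfolder='../output/'):
    X, Y = load_dataset(data_folder, sampling_rate=100)
    Y = compute_label_aggregations(Y, data_folder, 'superdiagnostic')
    X_sel, Y_sel, y, mlb = select_data(X, Y, 'superdiagnostic', min_samples=0, outputfolder=outputfolder)
    return X_sel, Y_sel, y, mlb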

def preprocess_signals(X_train, X_validation, X_test, outputfolder):
    # Standardize data such that mean 0 and variance 1
    ss = StandardScaler()
    ss.fit(np.vstack(X_train).flatten()[:,np.newaxis].astype(float))

    # Save Standardizer data
    with open(outputfolder+'standard_scaler.pkl', 'wb') as ss_file:
        pickle.dump(ss, ss_file)

    return apply_standardizer(X_train, ss), apply_standardizer(X_validation, ss), apply_standardizer(X_test, ss)

def apply_standardizer(X, ss):
    X_tmp = []
    for x in X:
        x_shape = x.shape
        X_tmp.append(ss.transform(x.flatten()[:,np.newaxis]).reshape(x_shape))
    X_tmp = np.array(X_tmp)
    return X_tmp
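
# Note on the scaling above: the StandardScaler is fit on all training samples and leads
# pooled into a single column, so apply_standardizer rescales every channel with the same
# global mean and variance (estimated on the training split only) and then restores each
# record's original (timesteps, leads) shape.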

# DOCUMENTATION STUFF

def generate_ptbxl_summary_table(selection=None, folder='../output/'):

    exps = ['exp0', 'exp1', 'exp1.1', 'exp1.1.1', 'exp2', 'exp3']
    metric1 = 'macro_auc'

    # get models
    models = {}
    for i, exp in enumerate(exps):
        if selection is None:
            exp_models = [m.split('/')[-1] for m in glob.glob(folder+str(exp)+'/models/*')]
        else:
            exp_models = selection
        if i == 0:
            models = set(exp_models)
        else:
            models = models.union(set(exp_models))

    results_dic = {'Method':[],
                'exp0_AUC':[],
                'exp1_AUC':[],
                'exp1.1_AUC':[],
                'exp1.1.1_AUC':[],
                'exp2_AUC':[],
                'exp3_AUC':[]
                }

    for m in models:
        results_dic['Method'].append(m)

        for e in exps:

            try:
                me_res = pd.read_csv(folder+str(e)+'/models/'+str(m)+'/results/te_results.csv', index_col=0)

                mean1 = me_res.loc['point'][metric1]
                unc1 = max(me_res.loc['upper'][metric1]-me_res.loc['point'][metric1], me_res.loc['point'][metric1]-me_res.loc['lower'][metric1])

                # report the point estimate with the larger deviation towards 'upper'/'lower'
                # (in thousandths) in parentheses, e.g. 0.925(08) for 0.925 +/- 0.008
                results_dic[e+'_AUC'].append("%.3f(%.2d)" %(np.round(mean1,3), int(unc1*1000)))

            except FileNotFoundError:
                results_dic[e+'_AUC'].append("--")

    df = pd.DataFrame(results_dic)
    df_index = df[df.Method.isin(['naive', 'ensemble'])]
    df_rest = df[~df.Method.isin(['naive', 'ensemble'])]
    df = pd.concat([df_rest, df_index])
    df.to_csv(folder+'results_ptbxl.csv')

    titles = [
        '### 1. PTB-XL: all statements',
        '### 2. PTB-XL: diagnostic statements',
        '### 3. PTB-XL: Diagnostic subclasses',
        '### 4. PTB-XL: Diagnostic superclasses',
        '### 5. PTB-XL: Form statements',
        '### 6. PTB-XL: Rhythm statements'
    ]

    # helper output for markdown tables
    our_work = 'https://arxiv.org/abs/2004.13701'
    our_repo = 'https://github.com/helme/ecg_ptbxl_benchmarking/'
    md_source = ''
    for i, e in enumerate(exps):
        md_source += '\n '+titles[i]+' \n \n'
        md_source += '| Model | AUC ↓ | paper/source | code | \n'
        md_source += '|---:|:---|:---|:---| \n'
        for row in df_rest[['Method', e+'_AUC']].sort_values(e+'_AUC', ascending=False).values:
            md_source += '| ' + row[0].replace('fastai_', '') + ' | ' + row[1] + ' | [our work]('+our_work+') | [this repo]('+our_repo+')| \n'
    print(md_source)

def ICBEBE_table(selection=None, folder='../output/'):
    cols = ['macro_auc', 'F_beta_macro', 'G_beta_macro']

    if selection is None:
        models = [m.split('/')[-1].split('_pretrained')[0] for m in glob.glob(folder+'exp_ICBEB/models/*')]
    else:
        # note: 'Wavelet+NN' used to be filtered out here
        models = list(selection)

    data = []
    for model in models:
        me_res = pd.read_csv(folder+'exp_ICBEB/models/'+model+'/results/te_results.csv', index_col=0)
        mcol = []
        for col in cols:
            mean = me_res.loc['point'][col]
            unc = max(me_res.loc['upper'][col]-me_res.loc['point'][col], me_res.loc['point'][col]-me_res.loc['lower'][col])
            mcol.append("%.3f(%.2d)" %(np.round(mean,3), int(unc*1000)))
        data.append(mcol)
    data = np.array(data)

    df = pd.DataFrame(data, columns=cols, index=models)
    df.to_csv(folder+'results_icbeb.csv')

    df_rest = df[~df.index.isin(['naive', 'ensemble'])]
    df_rest = df_rest.sort_values('macro_auc', ascending=False)
    our_work = 'https://arxiv.org/abs/2004.13701'
    our_repo = 'https://github.com/helme/ecg_ptbxl_benchmarking/'

    md_source = '| Model | AUC ↓ |  F_beta=2 | G_beta=2 | paper/source | code | \n'
    md_source += '|---:|:---|:---|:---|:---|:---| \n'
    for i, row in enumerate(df_rest[cols].values):
        md_source += '| ' + df_rest.index[i].replace('fastai_', '') + ' | ' + row[0] + ' | ' + row[1] + ' | ' + row[2] + ' | [our work]('+our_work+') | [this repo]('+our_repo+')| \n'
    print(md_source)