Diff of /utils.py [000000] .. [2095ed]

Switch to unified view

a b/utils.py
1
# Base / Native
2
import math
3
import os
4
import pickle
5
import re
6
import warnings
7
warnings.filterwarnings('ignore')
8
9
# Numerical / Array
10
import lifelines
11
from lifelines.utils import concordance_index
12
from lifelines import CoxPHFitter
13
from lifelines.datasets import load_regression_dataset
14
from lifelines.utils import k_fold_cross_validation
15
from lifelines.statistics import logrank_test
16
from imblearn.over_sampling import RandomOverSampler
17
import matplotlib as mpl
18
import matplotlib.pyplot as plt
19
import matplotlib.font_manager as font_manager
20
import numpy as np
21
import pandas as pd
22
from PIL import Image
23
import pylab
24
import scipy
25
import seaborn as sns
26
from sklearn import preprocessing
27
from sklearn.model_selection import train_test_split, KFold
28
from sklearn.metrics import average_precision_score, auc, f1_score, roc_curve, roc_auc_score
29
from sklearn.preprocessing import LabelBinarizer
30
31
from scipy import interp
32
mpl.rcParams['axes.linewidth'] = 3 #set the value globally
33
34
# Torch
35
import torch
36
import torch.nn as nn
37
from torch.nn import init, Parameter
38
from torch.utils.data._utils.collate import *
39
from torch.utils.data.dataloader import default_collate
40
import torch_geometric
41
from torch_geometric.data import Batch
42
43
44
45
################
46
# Regularization
47
################
48
def regularize_weights(model, reg_type=None):
    """Return the L1 norm (sum of absolute values) of every parameter of `model`.

    Parameters:
        model    -- object exposing `.parameters()` (e.g. a torch.nn.Module)
        reg_type -- accepted for interface compatibility; not used by this function

    Returns a scalar tensor, or None when the model has no parameters.
    """
    total = None
    for weight in model.parameters():
        magnitude = torch.abs(weight).sum()  # same value as weight.norm(1)
        total = magnitude if total is None else total + magnitude
    return total
57
58
59
def regularize_path_weights(model, reg_type=None):
    """Return the L1 norm of the `classifier` and `linear` sub-networks.

    Parameters:
        model    -- wrapper exposing the real network as `model.module`
                    (the network must have `classifier` and `linear` attributes)
        reg_type -- accepted for interface compatibility; not used by this function

    Returns a scalar tensor, or None when neither sub-network has parameters.
    """
    total = None
    for subnet in (model.module.classifier, model.module.linear):
        for weight in subnet.parameters():
            magnitude = torch.abs(weight).sum()  # same value as weight.norm(1)
            total = magnitude if total is None else total + magnitude
    return total
75
76
77
def regularize_MM_weights(model, reg_type=None):
    """Return the L1 norm of the multimodal sub-networks that `model.module` defines.

    Parameters:
        model    -- wrapper exposing the real network as `model.module`
        reg_type -- accepted for interface compatibility; not used by this function

    Only the sub-networks listed below are regularized, in this order, and a
    sub-network that the network does not have is skipped silently.

    Returns a scalar tensor, or None when none of the listed sub-networks
    exists (or they hold no parameters).
    """
    regularized_names = (
        'omic_net',
        'linear_h_path', 'linear_h_omic', 'linear_h_grph',
        'linear_z_path', 'linear_z_omic', 'linear_z_grph',
        'linear_o_path', 'linear_o_omic', 'linear_o_grph',
        'encoder1', 'encoder2',
        'classifier',
    )
    network = model.module
    l1_reg = None

    for name in regularized_names:
        # The builtin hasattr() works on any nn.Module; the previous
        # `network.__hasattr__(name)` call only worked on classes that
        # define that non-standard method themselves.
        if not hasattr(network, name):
            continue
        for W in getattr(network, name).parameters():
            magnitude = torch.abs(W).sum()  # same value as W.norm(1)
            l1_reg = magnitude if l1_reg is None else l1_reg + magnitude

    return l1_reg
172
173
174
def regularize_MM_omic(model, reg_type=None):
    """Return the L1 norm of the `omic_net` sub-network of `model.module`.

    Parameters:
        model    -- wrapper exposing the real network as `model.module`
        reg_type -- accepted for interface compatibility; not used by this function

    Returns a scalar tensor, or None when the network has no `omic_net`
    (or it holds no parameters).
    """
    network = model.module
    l1_reg = None

    # The builtin hasattr() works on any nn.Module; the previous
    # `network.__hasattr__('omic_net')` call only worked on classes that
    # define that non-standard method themselves.
    if hasattr(network, 'omic_net'):
        for W in network.omic_net.parameters():
            magnitude = torch.abs(W).sum()  # same value as W.norm(1)
            l1_reg = magnitude if l1_reg is None else l1_reg + magnitude

    return l1_reg
185
186
187
188
################
189
# Network Initialization
190
################
191
def init_weights(net, init_type='orthogonal', init_gain=0.02):
    """Initialize the weights of every Conv*/Linear*/BatchNorm2d layer in `net`, in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal

    Layers are matched by class name. Biases of matched layers are set to 0.
    An unknown `init_type` raises NotImplementedError, but only once a
    Conv*/Linear* layer is actually visited.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }

    def init_func(m):  # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type not in weight_initializers:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_initializers[init_type](m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is a vector, not a matrix, so it always
            # gets a normal distribution centred on 1.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
223
224
225
def init_max_weights(module):
    """Re-initialize every plain nn.Linear layer inside `module`, in place.

    Weights are drawn from N(0, 1/sqrt(fan_in)), where fan_in is the second
    dimension of the weight matrix; biases are zeroed. Only layers whose type
    is exactly nn.Linear are touched (subclasses are left alone).
    """
    for m in module.modules():
        if type(m) is nn.Linear:
            stdv = 1. / math.sqrt(m.weight.size(1))
            m.weight.data.normal_(0, stdv)
            # Layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                m.bias.data.zero_()
231
232
233
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Place a network on its device(s) and optionally initialize its weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- normal | xavier | kaiming | orthogonal, or
                              'none' / 'max' to leave the weights untouched here
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal
        gpu_ids (int list) -- which GPUs the network runs on, e.g. [0, 1, 2];
                              empty keeps the network where it is

    Returns the network, wrapped in DataParallel when `gpu_ids` is non-empty.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPU wrapper

    if init_type == 'none':
        print("Init Type: Not initializing networks.")
    elif init_type == 'max':
        # Only reported here; this function does not touch the weights for 'max'.
        print("Init Type: Self-Normalizing Weights")
    else:
        print("Init Type:", init_type)
        init_weights(net, init_type, init_gain=init_gain)
    return net
256
257
258
259
################
260
# Freeze / Unfreeze
261
################
262
def unfreeze_unimodal(opt, model, epoch):
    """Unfreeze the unimodal sub-networks used by `opt.mode` when `epoch` is 5.

    Parameters:
        opt   -- options object; only `opt.mode` is read
        model -- wrapper exposing the real network as `model.module`
        epoch -- current epoch number

    Modes not listed below, and every epoch other than 5, do nothing.
    """
    omic_step = ('omic_net', "Unfreezing Omic")
    graph_step = ('grph_net', "Unfreezing Graph")
    steps_by_mode = {
        'graphomic': (omic_step, graph_step),
        'pathomic': (omic_step,),
        'pathgraph': (graph_step,),
        'pathgraphomic': (omic_step, graph_step),
        'omicomic': (omic_step,),
        'graphgraph': (graph_step,),
    }

    steps = steps_by_mode.get(opt.mode, ())
    if epoch != 5:
        return

    for attribute, message in steps:
        dfs_unfreeze(getattr(model.module, attribute))
        print(message)
293
294
295
def dfs_freeze(model):
    """Set requires_grad=False on every parameter held by a descendant of `model`.

    Parameters registered directly on `model` itself (not inside a child
    module) are deliberately left untouched, as before.
    """
    # child.parameters() already walks the whole subtree of each child, so a
    # single pass over the direct children reaches every descendant parameter;
    # the former explicit recursion only repeated the same assignments.
    for child in model.children():
        for param in child.parameters():
            param.requires_grad = False
300
301
302
def dfs_unfreeze(model):
    """Set requires_grad=True on every parameter held by a descendant of `model`.

    Parameters registered directly on `model` itself (not inside a child
    module) are deliberately left untouched, as before.
    """
    # child.parameters() already walks the whole subtree of each child, so a
    # single pass over the direct children reaches every descendant parameter;
    # the former explicit recursion only repeated the same assignments.
    for child in model.children():
        for param in child.parameters():
            param.requires_grad = True
307
308
309
def print_if_frozen(module):
    """Print, for each parameter of each direct child, whether it is trainable.

    One line is printed per parameter, tagged with the index of the child
    that owns it and followed by the child's repr.
    """
    for idx, child in enumerate(module.children()):
        for param in child.parameters():
            if param.requires_grad == True:
                tag = "Learnable!!! %d:" % idx
            else:
                tag = "Still Frozen %d:" % idx
            print(tag, child)
316
317
318
def unfreeze_vgg_features(model, epoch):
    """Unfreeze the layers of `model.features` scheduled for `epoch`.

    The schedule maps an epoch to a layer index; every child of
    `model.features` whose index is greater than that value gets
    requires_grad=True, and the rest are only reported as frozen.

    Raises KeyError when `epoch` is not in the schedule (only epoch 30 is).
    """
    epoch_schedule = {30: 45}
    last_frozen_index = epoch_schedule[epoch]
    for idx, layer in enumerate(model.features.children()):
        if idx <= last_frozen_index:
            print("Still Frozen %d:" % idx, layer)
            continue
        print("Unfreezing %d:" % idx, layer)
        for param in layer.parameters():
            param.requires_grad = True
329
330
331
332
################
333
# Collate Utils
334
################
335
def mixed_collate(batch):
    """Collate a batch whose samples mix torch_geometric graphs with ordinary fields.

    `batch` is a sequence of samples, each a sequence of fields. The fields
    are regrouped position by position: a position whose first entry is a
    torch_geometric `Data` object is merged with `Batch.from_data_list`,
    every other position goes through `default_collate`.

    Returns a list with one collated entry per field position.
    """
    batch[0]  # kept so that an empty batch still fails with IndexError
    collated = []
    for samples in zip(*batch):
        if type(samples[0]) is torch_geometric.data.data.Data:
            collated.append(Batch.from_data_list(samples, []))
        else:
            collated.append(default_collate(samples))
    return collated
340
341
342
343
################
344
# Survival Utils
345
################
346
def CoxLoss(survtime, censor, hazard_pred, device):
    """Negative Cox partial log-likelihood, averaged over the batch.

    Parameters:
        survtime    -- survival times, one per sample
        censor      -- per-sample weight multiplied into each term of the loss
                       (presumably 1 for an observed event, 0 otherwise -- confirm with caller)
        hazard_pred -- predicted log-hazards; flattened to one value per sample
        device      -- device on which the risk-set matrix is built

    Returns a scalar tensor.
    """
    # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
    # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
    survtime = torch.as_tensor(survtime, device=device).reshape(-1)
    # R_mat[i, j] == 1 when survtime[j] >= survtime[i]. Built with one
    # broadcast comparison instead of a Python double loop.
    R_mat = (survtime.reshape(1, -1) >= survtime.reshape(-1, 1)).float()
    theta = hazard_pred.reshape(-1)
    exp_theta = torch.exp(theta)
    loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * censor)
    return loss_cox
360
361
362
def accuracy(output, labels):
    """Fraction of rows of `output` whose arg-max along dim 1 equals the label.

    Returns a 0-dim double tensor.
    """
    _, predicted = output.max(1)
    hits = predicted.type_as(labels).eq(labels).double().sum()
    return hits / len(labels)
367
368
369
def accuracy_cox(hazardsdata, labels):
    """Accuracy of median-dichotomized hazards against the true event labels.

    A sample is predicted as 1 when its hazard is strictly above the median
    of `hazardsdata`, and as 0 otherwise.
    """
    cutoff = np.median(hazardsdata)
    predicted_events = np.zeros(len(hazardsdata), dtype=int)
    predicted_events[hazardsdata > cutoff] = 1
    return np.sum(predicted_events == labels) / len(labels)
376
377
378
def cox_log_rank(hazardsdata, labels, survtime_all):
    """Log-rank p-value between the low- and high-hazard halves of a cohort.

    Samples are split at the median of `hazardsdata` (strictly above the
    median is the high group), and the two groups' survival times and event
    indicators are passed to lifelines' `logrank_test`.
    """
    cutoff = np.median(hazardsdata)
    high_risk = np.zeros(len(hazardsdata), dtype=int)
    high_risk[hazardsdata > cutoff] = 1

    low_mask = high_risk == 0
    high_mask = ~low_mask
    results = logrank_test(
        survtime_all[low_mask],
        survtime_all[high_mask],
        event_observed_A=labels[low_mask],
        event_observed_B=labels[high_mask],
    )
    return results.p_value
390
391
392
def CIndex(hazards, labels, survtime_all):
    """Concordance index of predicted hazards against observed survival.

    For every sample i with labels[i] == 1, each sample j that survived
    strictly longer forms a comparable pair. The pair scores 1 when j has the
    lower hazard and 0.5 when the two hazards are tied.

    Raises ZeroDivisionError when there is no comparable pair.
    """
    concord = 0.
    total = 0.
    N_test = labels.shape[0]
    for i in range(N_test):
        if labels[i] != 1:
            continue
        for j in range(N_test):
            if survtime_all[j] > survtime_all[i]:
                total += 1
                if hazards[j] < hazards[i]:
                    concord += 1
                # The tie branch used to repeat the `<` test above, so it
                # could never run and ties scored 0.
                elif hazards[j] == hazards[i]:
                    concord += 0.5

    return(concord/total)
405
406
407
def CIndex_lifeline(hazards, labels, survtime_all):
    """Concordance index computed by lifelines' `concordance_index`.

    The hazards are negated before being passed as the predicted scores.
    """
    predicted_scores = -hazards
    return concordance_index(survtime_all, predicted_scores, labels)
409
410
411
412
################
413
# Data Utils
414
################
415
def addHistomolecularSubtype(data):
    """Return a copy of `data` with a leading 'Histomolecular subtype' column.

    Molecular Subtype: IDHwt == 0, IDHmut-non-codel == 1, IDHmut-codel == 2
    Histology Subtype: astrocytoma == 0, oligoastrocytoma == 1, oligodendroglioma == 2, glioblastoma == 3

    Labels assigned:
        'idhwt_ATC'  -- molecular subtype 0 with histology 0 or 3
        'idhmut_ATC' -- molecular subtype 1 with histology 0 or 3
        'ODG'        -- molecular subtype 2 with histology 2
    Rows matching none of these keep the placeholder value 1.0.

    `data` itself is not modified. It must not already contain a
    'Histomolecular subtype' column (DataFrame.insert refuses duplicates).
    """
    subtyped_data = data.copy()
    # The placeholder column is created with object dtype so that string
    # labels can be written into it; assigning strings into a float64 column
    # is an incompatible-dtype setitem that recent pandas warns about or rejects.
    placeholder = np.ones(len(data)).astype(object)
    subtyped_data.insert(loc=0, column='Histomolecular subtype', value=placeholder)

    molecular = data['Molecular subtype']
    histology = data['Histology']
    astrocytic = (histology == 0) | (histology == 3)

    subtyped_data.loc[(molecular == 0) & astrocytic, 'Histomolecular subtype'] = 'idhwt_ATC'
    subtyped_data.loc[(molecular == 1) & astrocytic, 'Histomolecular subtype'] = 'idhmut_ATC'
    subtyped_data.loc[(molecular == 2) & (histology == 2), 'Histomolecular subtype'] = 'ODG'
    return subtyped_data


def changeHistomolecularSubtype(data):
    """Recompute the 'Histomolecular subtype' column of `data`.

    Molecular Subtype: IDHwt == 0, IDHmut-non-codel == 1, IDHmut-codel == 2
    Histology Subtype: astrocytoma == 0, oligoastrocytoma == 1, oligodendroglioma == 2, glioblastoma == 3

    The existing column is dropped and rebuilt by addHistomolecularSubtype,
    so both functions always apply the same labelling rules. `data` itself
    is not modified.
    """
    return addHistomolecularSubtype(data.drop(['Histomolecular subtype'], axis=1))
450
451
452
def getCleanAllDataset(dataroot='./data/TCGA_GBMLGG/', ignore_missing_moltype=False, ignore_missing_histype=False, use_rnaseq=False):
    """Load, join and clean the TCGA GBM/LGG tables found under `dataroot`.

    Parameters:
        dataroot               -- directory holding 'all_dataset.csv', 'grade_data.csv'
                                  and, when `use_rnaseq` is set, the two RNA-seq z-score tables
        ignore_missing_moltype -- drop patients without a molecular subtype
                                  (otherwise the subtype is filled with 'Missing')
        ignore_missing_histype -- drop patients without a histology
                                  (otherwise histology is filled with 'Missing' and grade with 1)
        use_rnaseq             -- also join the RNA-seq features (columns suffixed '_rnaseq')

    Returns (metadata, all_dataset): the list of metadata column names and
    the cleaned DataFrame indexed by 'TCGA ID'.

    Raises AssertionError when the patient IDs of the two CSV files differ or
    when the sets of patients with missing values are inconsistent.
    """
    ### 1. Joining all_datasets.csv with grade data. Looks at columns with missing samples
    metadata = ['Histology', 'Grade', 'Molecular subtype', 'TCGA ID', 'censored', 'Survival months']
    all_dataset = pd.read_csv(os.path.join(dataroot, 'all_dataset.csv')).drop('indexes', axis=1)
    all_dataset.index = all_dataset['TCGA ID']

    all_grade = pd.read_csv(os.path.join(dataroot, 'grade_data.csv'))
    all_grade['Histology'] = all_grade['Histology'].str.replace('astrocytoma (glioblastoma)', 'glioblastoma', regex=False)
    all_grade.index = all_grade['TCGA ID']
    assert pd.Series(all_dataset.index).equals(pd.Series(sorted(all_grade.index)))

    all_dataset = all_dataset.join(all_grade[['Histology', 'Grade', 'Molecular subtype']], how='inner')
    # Move the three joined columns to the front.
    cols = all_dataset.columns.tolist()
    all_dataset = all_dataset[cols[-3:] + cols[:-3]]

    if use_rnaseq:
        gbm = pd.read_csv(os.path.join(dataroot, 'mRNA_Expression_z-Scores_RNA_Seq_RSEM.txt'), sep='\t', skiprows=1, index_col=0)
        lgg = pd.read_csv(os.path.join(dataroot, 'mRNA_Expression_Zscores_RSEM.txt'), sep='\t', skiprows=1, index_col=0)
        # Drop columns that are entirely null before joining.
        gbm = gbm[gbm.columns[~gbm.isnull().all()]]
        lgg = lgg[lgg.columns[~lgg.isnull().all()]]
        glioma_RNAseq = gbm.join(lgg, how='inner').T
        glioma_RNAseq = glioma_RNAseq.dropna(axis=1)
        glioma_RNAseq.columns = [gene+'_rnaseq' for gene in glioma_RNAseq.columns]
        # Keep the first 12 characters of each sample name as the patient ID,
        # then keep only the first row per patient.
        glioma_RNAseq.index = [patname[:12] for patname in glioma_RNAseq.index]
        glioma_RNAseq = glioma_RNAseq.iloc[~glioma_RNAseq.index.duplicated()]
        glioma_RNAseq.index.name = 'TCGA ID'
        all_dataset = all_dataset.join(glioma_RNAseq, how='inner')

    pat_missing_moltype = all_dataset[all_dataset['Molecular subtype'].isna()].index
    pat_missing_idh = all_dataset[all_dataset['idh mutation'].isna()].index
    pat_missing_1p19q = all_dataset[all_dataset['codeletion'].isna()].index
    print("# Missing Molecular Subtype:", len(pat_missing_moltype))
    print("# Missing IDH Mutation:", len(pat_missing_idh))
    print("# Missing 1p19q Codeletion:", len(pat_missing_1p19q))
    assert pat_missing_moltype.equals(pat_missing_idh)
    assert pat_missing_moltype.equals(pat_missing_1p19q)
    pat_missing_grade = all_dataset[all_dataset['Grade'].isna()].index
    pat_missing_histype = all_dataset[all_dataset['Histology'].isna()].index
    print("# Missing Histological Subtype:", len(pat_missing_histype))
    print("# Missing Grade:", len(pat_missing_grade))
    assert pat_missing_histype.equals(pat_missing_grade)

    ### 2. Impute Missing Genomic Data: Removes patients with missing molecular subtype / idh mutation / 1p19q. Else imputes with median value of each column. Fills missing Molecular subtype with "Missing"
    if ignore_missing_moltype:
        all_dataset = all_dataset[~all_dataset['Molecular subtype'].isna()].copy()
    # Filled once here; this used to be repeated inside the per-column loop below.
    all_dataset['Molecular subtype'] = all_dataset['Molecular subtype'].fillna('Missing')
    for col in all_dataset.drop(metadata, axis=1).columns:
        all_dataset[col] = all_dataset[col].fillna(all_dataset[col].median())

    ### 3. Impute Missing Histological Data: Removes patients with missing histological subtype / grade. Else imputes with "missing" / grade -1
    if ignore_missing_histype:
        all_dataset = all_dataset[~all_dataset['Histology'].isna()].copy()
    else:
        # Filled with 1 so that the shift by 2 below turns it into -1.
        all_dataset['Grade'] = all_dataset['Grade'].fillna(1)
        all_dataset['Histology'] = all_dataset['Histology'].fillna('Missing')
    all_dataset['Grade'] = all_dataset['Grade'] - 2

    ### 4. Adds Histomolecular subtype
    # Values that are not keys of the mapping are kept unchanged.
    ms2int = {'Missing':-1, 'IDHwt':0, 'IDHmut-non-codel':1, 'IDHmut-codel':2}
    all_dataset['Molecular subtype'] = all_dataset['Molecular subtype'].map(lambda s: ms2int.get(s, s))
    hs2int = {'Missing':-1, 'astrocytoma':0, 'oligoastrocytoma':1, 'oligodendroglioma':2, 'glioblastoma':3}
    all_dataset['Histology'] = all_dataset['Histology'].map(lambda s: hs2int.get(s, s))
    all_dataset = addHistomolecularSubtype(all_dataset)
    metadata.extend(['Histomolecular subtype'])
    # Invert the 0/1 'censored' flag.
    all_dataset['censored'] = 1 - all_dataset['censored']
    return metadata, all_dataset
519
520
521
522
################
523
# Analysis Utils
524
################
525
def count_parameters(model):
    """Return the number of scalar entries in the trainable parameters of `model`."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
527
528
529
def hazard2grade(hazard, p):
    """Bin a hazard into 0, 1 or 2 using the two cut points p[0] and p[1].

    Returns 0 when hazard < p[0], 1 when p[0] <= hazard < p[1], else 2.
    """
    if hazard < p[0]:
        grade = 0
    else:
        grade = 1 if hazard < p[1] else 2
    return grade
535
536
537
def p(n):
    """Build a named aggregation function returning the n-th percentile.

    `n` is handed to `np.percentile` unchanged, so it is on the 0-100 scale.
    The returned function is named 'p<n>' (e.g. 'p75').
    """
    def nth_percentile(values):
        return np.percentile(values, n)

    nth_percentile.__name__ = 'p%s' % n
    return nth_percentile
542
543
544
def natural_sort(l):
    """Return the strings of `l` sorted with embedded numbers compared numerically.

    Non-numeric chunks are compared case-insensitively, so 'a2' sorts before 'a10'.
    """
    def alphanum_key(key):
        chunks = re.split('([0-9]+)', key)
        return [int(chunk) if chunk.isdigit() else chunk.lower() for chunk in chunks]

    return sorted(l, key=alphanum_key)
548
549
550
def CI_pm(data, confidence=0.95):
    """Format the mean of `data` with its t-distribution half-width, as 'mean ± h'.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student-t quantile for `confidence`, with len(data) - 1 degrees
    of freedom. The mean is printed with 4 decimals and the half-width with 3.
    """
    values = 1.0 * np.array(data)
    dof = len(values) - 1
    center = np.mean(values)
    std_err = scipy.stats.sem(values)
    half_width = std_err * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    center_text = "{0:.4f} ± ".format(center)
    half_width_text = "{0:.3f}".format(half_width)
    return center_text + half_width_text
556
557
558
def CI_interval(data, confidence=0.95):
    """Format the t-distribution confidence interval of the mean as 'low, high'.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student-t quantile for `confidence`, with len(data) - 1 degrees
    of freedom. Both bounds are printed with 3 decimals.
    """
    values = 1.0 * np.array(data)
    dof = len(values) - 1
    center = np.mean(values)
    std_err = scipy.stats.sem(values)
    half_width = std_err * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    lower_text = "{0:.3f}, ".format(center - half_width)
    upper_text = "{0:.3f}".format(center + half_width)
    return lower_text + upper_text
564
565
566
def poolSurvTestPD(ckpt_name='./checkpoints/TCGA_GBMLGG/surv_15_rnaseq/', model='pathgraphomic_fusion', split='test', zscore=False, agg_type='Hazard_mean'):
    """Pool per-patient hazard predictions over the 15 cross-validation folds.

    Parameters:
        ckpt_name -- checkpoint directory holding '<model>/<model>_<k>..pred_<split>.pkl'
        model     -- model name; also decides which split file / patch suffix is used
        split     -- name of the split inside each fold, e.g. 'test'
        zscore    -- z-score the hazards of each fold before pooling
        agg_type  -- per-patient aggregate to keep: 'Hazard_mean', 'Hazard_median',
                     'Hazard_max', 'Hazard_p0.25' or 'Hazard_p0.75'

    Returns one DataFrame with the rows of all folds concatenated and the
    'Histomolecular subtype' column recomputed.

    When 'cox' is in `model`, each fold's pickle is pooled as loaded, without
    the registration / aggregation step.

    NOTE: prediction and split files are read with pickle; only load files
    from a trusted source.
    """
    all_dataset_regstrd_pooled = []
    ignore_missing_moltype = 1 if 'omic' in model else 0
    ignore_missing_histype = 1 if 'grad' in ckpt_name else 0
    uses_patches = (('path' in model) or ('graph' in model)) and ('cox' not in model)
    use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if uses_patches else ('_', 'all_st', 0)
    use_rnaseq = '_rnaseq' if ('rnaseq' in ckpt_name and 'path' != model and 'pathpath' not in model and 'graph' != model and 'graphgraph' not in model) else ''

    # The split file does not depend on the fold, so it is read once, the
    # first time it is needed.
    data_cv = None

    for k in range(1,16):
        with open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb') as pred_file:
            pred = pickle.load(pred_file)

        if 'cox' not in model:
            # pred holds five entries; entry 3 is skipped. Picking the rows
            # explicitly avoids np.array() on a ragged list, which recent
            # NumPy versions reject.
            surv_all = pd.DataFrame(np.stack([pred[0], pred[1], pred[2], pred[4]])).T
            surv_all.columns = ['Hazard', 'Survival months', 'censored', 'Grade']
            if data_cv is None:
                with open('./data/TCGA_GBMLGG/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (roi_dir, ignore_missing_moltype, ignore_missing_histype, use_vgg_features, use_rnaseq), 'rb') as split_file:
                    data_cv = pickle.load(split_file)
            data_cv_splits = data_cv['cv_splits']
            data_cv_split_k = data_cv_splits[k]
            assert np.all(data_cv_split_k[split]['t'] == pred[1]) # Data is correctly registered
            all_dataset = data_cv['data_pd'].drop('TCGA ID', axis=1)
            all_dataset_regstrd = all_dataset.loc[data_cv_split_k[split]['x_patname']] # Subset of "all_datasets" (metadata) that is registered with "pred" (predictions)
            assert np.all(np.array(all_dataset_regstrd['Survival months']) == pred[1])
            assert np.all(np.array(all_dataset_regstrd['censored']) == pred[2])
            assert np.all(np.array(all_dataset_regstrd['Grade']) == pred[4])
            all_dataset_regstrd.insert(loc=0, column='Hazard', value = np.array(surv_all['Hazard']))
            all_dataset_regstrd.index.name = 'TCGA ID'
            # NOTE(review): np.percentile works on a 0-100 scale, so p(0.25) and
            # p(0.75) are the 0.25th / 0.75th percentiles, not the quartiles.
            # Left as is because agg_type values such as 'Hazard_p0.75' depend
            # on these names -- confirm the intent.
            hazard_agg = all_dataset_regstrd.groupby('TCGA ID').agg({'Hazard': ['mean', 'median', 'max', p(0.25), p(0.75)]})
            hazard_agg.columns = ["_".join(x) for x in hazard_agg.columns]
            hazard_agg = hazard_agg[[agg_type]]
            hazard_agg.columns = ['Hazard']
            pred = hazard_agg.join(all_dataset, how='inner')

        if zscore: pred['Hazard'] = scipy.stats.zscore(np.array(pred['Hazard']))
        all_dataset_regstrd_pooled.append(pred)

    all_dataset_regstrd_pooled = pd.concat(all_dataset_regstrd_pooled)
    all_dataset_regstrd_pooled = changeHistomolecularSubtype(all_dataset_regstrd_pooled)
    return all_dataset_regstrd_pooled
602
603
604
def getAggHazardCV(ckpt_name='./checkpoints/TCGA_GBMLGG/surv_15_rnaseq/', model='pathgraphomic_fusion', split='test', agg_type='Hazard_mean'):
    """Compute the per-fold concordance index of patient-aggregated hazards.

    Parameters:
        ckpt_name -- checkpoint directory holding '<model>/<model>_<k>..pred_<split>.pkl'
        model     -- model name; also decides which split file / patch suffix is used
        split     -- name of the split inside each fold, e.g. 'test'
        agg_type  -- per-patient aggregate to score: 'Hazard_mean', 'Hazard_max'
                     or 'Hazard_p0.75'

    Returns a list with one concordance index per fold (folds 1 to 15).

    NOTE: prediction and split files are read with pickle; only load files
    from a trusted source.
    """
    result = []

    ignore_missing_moltype = 1 if 'omic' in model else 0
    ignore_missing_histype = 1 if 'grad' in ckpt_name else 0
    use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if (('path' in model) or ('graph' in model)) else ('_', 'all_st', 0)
    use_rnaseq = '_rnaseq' if ('rnaseq' in ckpt_name and 'path' != model and 'pathpath' not in model and 'graph' != model and 'graphgraph' not in model) else ''

    # The split file does not depend on the fold, so it is read once, the
    # first time it is needed.
    data_cv = None

    for k in range(1,16):
        with open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb') as pred_file:
            pred = pickle.load(pred_file)
        # pred holds five entries; entry 3 is skipped. Picking the rows
        # explicitly avoids np.array() on a ragged list, which recent NumPy
        # versions reject.
        surv_all = pd.DataFrame(np.stack([pred[0], pred[1], pred[2], pred[4]])).T
        surv_all.columns = ['Hazard', 'Survival months', 'censored', 'Grade']
        if data_cv is None:
            with open('./data/TCGA_GBMLGG/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (roi_dir, ignore_missing_moltype, ignore_missing_histype, use_vgg_features, use_rnaseq), 'rb') as split_file:
                data_cv = pickle.load(split_file)
        data_cv_splits = data_cv['cv_splits']
        data_cv_split_k = data_cv_splits[k]
        assert np.all(data_cv_split_k[split]['t'] == pred[1]) # Data is correctly registered
        all_dataset = data_cv['data_pd'].drop('TCGA ID', axis=1)
        all_dataset_regstrd = all_dataset.loc[data_cv_split_k[split]['x_patname']] # Subset of "all_datasets" (metadata) that is registered with "pred" (predictions)
        assert np.all(np.array(all_dataset_regstrd['Survival months']) == pred[1])
        assert np.all(np.array(all_dataset_regstrd['censored']) == pred[2])
        assert np.all(np.array(all_dataset_regstrd['Grade']) == pred[4])
        all_dataset_regstrd.insert(loc=0, column='Hazard', value = np.array(surv_all['Hazard']))
        all_dataset_regstrd.index.name = 'TCGA ID'
        # NOTE(review): np.percentile works on a 0-100 scale, so p(0.75) is the
        # 0.75th percentile, not the upper quartile. Left as is because the
        # agg_type value 'Hazard_p0.75' depends on this name -- confirm the intent.
        hazard_agg = all_dataset_regstrd.groupby('TCGA ID').agg({'Hazard': ['mean', 'max', p(0.75)]})
        hazard_agg.columns = ["_".join(x) for x in hazard_agg.columns]
        hazard_agg = hazard_agg[[agg_type]]
        hazard_agg.columns = ['Hazard']
        all_dataset_hazard = hazard_agg.join(all_dataset, how='inner')
        cin = CIndex_lifeline(all_dataset_hazard['Hazard'], all_dataset_hazard['censored'], all_dataset_hazard['Survival months'])
        result.append(cin)

    return result
636
637
638
def calcGradMetrics(ckpt_name='./checkpoints/grad_15/', model='pathgraphomic_fusion', split='test', avg='micro'):
    """Summarize grade-classification metrics over the 15 cross-validation folds.

    Parameters:
        ckpt_name -- checkpoint directory holding '<model>/<model>_<k>..pred_<split>.pkl'
        model     -- model name; decides whether the '_patch_' file suffix is used
        split     -- name of the split inside each fold, e.g. 'test'
        avg       -- averaging mode passed to the scikit-learn metrics

    Returns an array of four 'mean ± half-width' strings (from CI_pm): ROC AUC,
    average precision, F1, and the F1 of the class at index 2.

    NOTE: prediction files are read with pickle; only load files from a
    trusted source.
    """
    auc_all = []
    ap_all = []
    f1_all = []
    f1_gradeIV_all = []

    use_patch = '_patch_' if (('path' in model) or ('graph' in model)) else '_'

    for k in range(1,16):
        with open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb') as pred_file:
            pred = pickle.load(pred_file)
        # Entry 3 holds the per-class scores, entry 4 the true grades.
        grade_pred, grade = np.array(pred[3]), np.array(pred[4])
        enc = LabelBinarizer()
        enc.fit(grade)
        grade_oh = enc.transform(grade)
        grade_hat = grade_pred.argmax(axis=1)

        # `average` must be passed by keyword: it is keyword-only in current
        # scikit-learn releases.
        rocauc = roc_auc_score(grade_oh, grade_pred, average=avg)
        ap = average_precision_score(grade_oh, grade_pred, average=avg)
        # f1_score expects (y_true, y_pred). The swapped order gave the same
        # result for 'micro', 'macro' and None, but not for 'weighted'.
        f1 = f1_score(grade, grade_hat, average=avg)
        f1_gradeIV = f1_score(grade, grade_hat, average=None)[2]

        auc_all.append(rocauc)
        ap_all.append(ap)
        f1_all.append(f1)
        f1_gradeIV_all.append(f1_gradeIV)

    return np.array([CI_pm(auc_all), CI_pm(ap_all), CI_pm(f1_all), CI_pm(f1_gradeIV_all)])
665
666
667
668
################
669
# Plot Utils
670
################
671
def makeKaplanMeierPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='omic', split='test', zscore=False, agg_type='Hazard_mean'):
    """Save Kaplan-Meier plots comparing true grade with hazard-derived risk groups.

    The pooled table from ``poolSurvTestPD`` is plotted once per histomolecular
    subtype ('idhwt_ATC', 'idhmut_ATC', 'ODG') and once for all rows together.
    Each figure is written to ``<ckpt_name>/<model>_KM_<subtype>.png``.

    Args:
        ckpt_name: Checkpoint directory; also the output directory.
        model: Model name, used for loading, legend labels and file names.
        split: Forwarded to ``poolSurvTestPD``.
        zscore: Forwarded to ``poolSurvTestPD``.
        agg_type: Forwarded to ``poolSurvTestPD``.
    """
    # First key contained in the model name wins, so longer keys come first.
    mode2name = {'pathgraphomic':'Pathomic F.', 'pathomic':'Pathomic F.', 'graphomic':'Pathomic F.', 'path':'Histology CNN', 'graph':'Histology GCN', 'omic':'Genomic SNN'}
    curve_name = next((name for mode, name in mode2name.items() if mode in model), 'N/A')

    def hazard2KMCurve(data, subtype):
        """Build one figure with ground-truth (dashed) and predicted (solid) curves."""
        # Work on a private copy so the caller's frame does not gain a column.
        data = data.copy()

        p = np.percentile(data['Hazard'], [33, 66])
        # Both cut-offs coincide: the lower one is overwritten with a fixed value.
        # NOTE(review): the rationale for 2.99997 is not evident from this function.
        if p[0] == p[1]: p[0] = 2.99997
        data.insert(0, 'grade_pred', [hazard2grade(hazard, p) for hazard in data['Hazard']])

        fig, ax = plt.subplots(figsize=(10, 10), dpi=600)
        censor_style = {'ms': 20, 'marker': '+'}

        # (level, colour, ground-truth label, predicted-risk label, extra plot kwargs)
        strata = [
            (0, 'g', "Grade II", "Low", {'markerfacecolor': 'black'}),
            (1, 'b', "Grade III", "Mid", {}),
            (2, 'r', "Grade IV", "High", {}),
        ]
        if subtype == 'ODG':
            # No third stratum is drawn for the ODG subtype.
            strata = strata[:2]

        for level, colour, gt_label, risk, extra in strata:
            curves = (
                ('Grade', gt_label, 3, '--'),
                ('grade_pred', "%s (%s)" % (curve_name, risk), 4, '-'),
            )
            for column, label, width, line_style in curves:
                group = data[data[column] == level]
                if group.empty:
                    # Nothing to fit for an empty stratum; skip it rather than fail.
                    continue
                kmf = lifelines.KaplanMeierFitter()
                # NOTE(review): a column named 'Survival months' is divided by 365 -- confirm units.
                kmf.fit(group['Survival months']/365, group['censored'], label=label)
                kmf.plot(ax=ax, show_censors=True, ci_show=False, c=colour, linewidth=width,
                         ls=line_style, censor_styles=censor_style, **extra)

        ax.set_xlabel('')
        ax.set_ylim(0, 1)
        ax.set_yticks(np.arange(0, 1.001, 0.5))
        ax.tick_params(axis='both', which='major', labelsize=40)

        ax.legend(fontsize=32, prop=font_manager.FontProperties(family='Arial', style='normal', size=32))
        # Only the 'idhwt_ATC' panel keeps its legend.
        if subtype != 'idhwt_ATC': ax.get_legend().remove()
        return fig

    data = poolSurvTestPD(ckpt_name, model, split, zscore, agg_type)
    for subtype in ['idhwt_ATC', 'idhmut_ATC', 'ODG']:
        fig = hazard2KMCurve(data[data['Histomolecular subtype'] == subtype], subtype)
        fig.savefig(ckpt_name+'/%s_KM_%s.png' % (model, subtype))

    fig = hazard2KMCurve(data, 'all')
    fig.savefig(ckpt_name+'/%s_KM_%s.png' % (model, 'all'))
727
728
729
def makeHazardSwarmPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='path', split='test', zscore=True, agg_type='Hazard_mean'):
    """Save a swarm plot of hazard per histomolecular subtype, coloured by grade.

    Rows whose 'Grade' or 'Histomolecular subtype' equals -1 are dropped before
    plotting. The figure is written to ``<ckpt_name>/<model>_HSP.png``.

    Args:
        ckpt_name: Checkpoint directory; also the output directory.
        model: Model name, used for loading and for the output file name.
        split: Forwarded to ``poolSurvTestPD``.
        zscore: Forwarded to ``poolSurvTestPD``.
        agg_type: Forwarded to ``poolSurvTestPD``.
    """
    mpl.rcParams['font.family'] = "arial"
    data = poolSurvTestPD(ckpt_name=ckpt_name, model=model, split=split, zscore=zscore, agg_type=agg_type)

    # Drop rows marked -1 and take a copy so the column assignments below
    # modify an independent frame rather than a slice of the pooled table.
    keep = (data['Grade'] != -1) & (data['Histomolecular subtype'] != -1)
    data = data[keep].copy()

    # Whole-value mapping of the numeric grade codes to display labels.
    grade_names = {'0': 'Grade II', '1': 'Grade III', '2': 'Grade IV'}
    data['Grade'] = data['Grade'].astype(int).astype(str).replace(grade_names)

    subtype_names = {
        'idhwt_ATC': 'IDH-wt \n astrocytoma',
        'idhmut_ATC': 'IDH-mut \n astrocytoma',
        'ODG': 'Oligodendroglioma',
    }
    for code, display in subtype_names.items():
        data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace(code, display, regex=False)

    fig, ax = plt.subplots(dpi=600)
    ax.set_ylim([-2, 2.5])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_yticks(np.arange(-2, 2.001, 1))

    sns.swarmplot(x='Histomolecular subtype', y='Hazard', data=data, hue='Grade',
                  palette={"Grade II":"#AFD275" , "Grade III":"#7395AE", "Grade IV":"#E7717D"},
                  size=4, alpha=0.9, ax=ax)

    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.tick_params(axis='y', which='both', labelsize=20)
    ax.tick_params(axis='x', which='both', labelsize=15)
    # NOTE(review): 'off' is a string, not a boolean, so this call is not
    # expected to hide the x tick labels; pass labelbottom=False if hiding
    # them is really intended. Left unchanged to keep the current figure.
    ax.tick_params(axis='x', which='both', labelbottom='off')
    ax.legend(prop={'size': 8})
    fig.savefig(ckpt_name+'/%s_HSP.png' % (model))
759
760
761
def makeHazardBoxPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='omic', split='test', zscore=True, agg_type='Hazard_mean'):
    """Save side-by-side box plots of hazard by grade, one panel per subtype.

    Panels are drawn for 'idhwt_ATC', 'idhmut_ATC' and 'ODG' (the last one with
    grades II and III only). The figure is written to
    ``<ckpt_name>/<model>_HBP.png``.

    Args:
        ckpt_name: Checkpoint directory; also the output directory.
        model: Model name, used for loading and for the output file name.
        split: Forwarded to ``poolSurvTestPD``.
        zscore: Forwarded to ``poolSurvTestPD``.
        agg_type: Aggregation column forwarded to ``poolSurvTestPD``.
    """
    mpl.rcParams['font.family'] = "arial"
    # Honour the caller's agg_type (the default is still 'Hazard_mean').
    data = poolSurvTestPD(ckpt_name, model, split, zscore, agg_type)

    # Whole-value mapping of grade codes to Roman numerals; any other value is
    # left as is and is excluded from the panels by `order` below.
    grade_names = {'0': 'II', '1': 'III', '2': 'IV'}
    data['Grade'] = data['Grade'].astype(int).astype(str).replace(grade_names)

    fig, axes = plt.subplots(nrows=1, ncols=3, gridspec_kw={'width_ratios': [3, 3, 2]}, dpi=600)
    fig.subplots_adjust(wspace=0, hspace=0)

    panel_colors = ['#F76C6C', '#A8D0E6', '#F8E9A1']
    subtypes = ['idhwt_ATC', 'idhmut_ATC', 'ODG']

    for idx, (ax, subtype, face_color) in enumerate(zip(axes, subtypes, panel_colors)):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.xaxis.grid(False)
        ax.yaxis.grid(False)

        if idx > 0:
            # Only the left-most panel shows a y axis.
            ax.get_yaxis().set_visible(False)
            ax.spines["left"].set_visible(False)

        order = ["II", "III"] if subtype == 'ODG' else ["II", "III", "IV"]

        ax.xaxis.label.set_visible(False)
        ax.yaxis.label.set_visible(False)
        ax.tick_params(axis='y', which='both', labelsize=20)
        ax.tick_params(axis='x', which='both', labelsize=15)

        datapoints = data[data['Histomolecular subtype'] == subtype]
        sns.boxplot(y='Hazard', x="Grade", data=datapoints, ax=ax, color=face_color, order=order)
        sns.stripplot(y='Hazard', x='Grade', data=datapoints, alpha=0.2, jitter=0.2, color='k', ax=ax, order=order)
        ax.set_ylim(-2.5, 2.5)
        ax.set_yticks(np.arange(-2.0, 2.1, 1))

    fig.savefig(ckpt_name+'/%s_HBP.png' % (model))
802
803
804
def makeAUROCPlot(ckpt_name='./checkpoints/grad_15/', model_list=['path', 'omic', 'pathgraphomic_fusion'], split='test', avg='micro', use_zoom=False):
    """Save mean ROC curves (with a one-std band) per class and for all classes jointly.

    One figure is produced for each of classes 0, 1, 2 and one for the joint
    curve obtained by flattening the one-hot labels and scores. Every figure
    overlays all models in ``model_list``, averaging the interpolated ROC
    curves of the 15 saved splits. Files are written to
    ``<ckpt_name>/AUC_<class>[_zoom].png``.

    Args:
        ckpt_name: Directory with one sub-directory per model; also the output directory.
        model_list: Models to overlay; each must be a key of the colour/name tables below.
        split: Split suffix used in the prediction file names.
        avg: Label of the joint curve; used in the last output file name.
        use_zoom: If true, restrict the axes to a fixed per-class window and omit the legend.
    """
    mpl.rcParams['font.family'] = "arial"
    colors = {'path':'dodgerblue', 'graph':'orange', 'omic':'green', 'pathgraphomic_fusion':'crimson'}
    names = {'path':'Histology CNN', 'graph':'Histology GCN', 'omic':'Genomic SNN', 'pathgraphomic_fusion':'Pathomic F.'}
    # class -> ((x_min, x_max), (y_min, y_max)) window used when use_zoom is set
    zoom_params = {0:([0.2, 0.4], [0.8, 1.0]),
                   1:([0.25, 0.45], [0.75, 0.95]),
                   2:([0.0, 0.2], [0.8, 1.0]),
                   'micro':([0.15, 0.35], [0.8, 1.0])}
    mean_fpr = np.linspace(0, 1, 100)
    classes = [0, 1, 2, avg]

    # Track the figures created here instead of harvesting every open pyplot
    # figure, so unrelated open figures cannot shift or overflow the class index.
    figures = []

    ### 1. Looping over classes
    for i in classes:
        print("Class: " + str(i))
        fi = plt.figure(figsize=(10,10), dpi=600, linewidth=0.2)
        axi = plt.subplot()
        figures.append(fi)

        ### 2. Looping over models
        for model in model_list:
            # Prediction files of models with a 'path' or 'graph' component carry '_patch_'.
            use_patch = '_patch_' if ('path' in model or 'graph' in model) else '_'

            ### 3. Looping over all splits
            tprs, aucrocs = [], []
            for k in range(1, 16):
                pred_path = ckpt_name + '/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split)
                # NOTE: unpickling can execute arbitrary code -- only load trusted checkpoints.
                with open(pred_path, 'rb') as pred_file:
                    pred = pickle.load(pred_file)
                grade_pred, grade = np.array(pred[3]), np.array(pred[4])
                grade_oh = LabelBinarizer().fit_transform(grade)

                if i != avg:
                    # One-vs-rest curve for class i.
                    fpr, tpr, _ = roc_curve(grade_oh[:, i], grade_pred[:, i], drop_intermediate=False)
                else:
                    # Joint curve: every (sample, class) pair is one binary decision.
                    fpr, tpr, _ = roc_curve(grade_oh.ravel(), grade_pred.ravel())

                aucrocs.append(auc(fpr, tpr))
                # Resample onto the common FPR grid so the splits can be averaged.
                tprs.append(np.interp(mean_fpr, fpr, tpr))
                tprs[-1][0] = 0.0

            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            mean_auc = np.mean(aucrocs)
            std_auc = np.std(aucrocs)
            print('\t'+'%s - AUC: %0.3f ± %0.3f' % (model, mean_auc, std_auc))

            is_fusion = model == 'pathgraphomic_fusion'
            alpha = 0.8 if is_fusion else 0.5
            if use_zoom:
                lw = 6
            else:
                lw = 4 if is_fusion else 3

            axi.plot(mean_fpr, mean_tpr, color=colors[model],
                     label=r'%s (AUC = %0.3f $\pm$ %0.3f)' % (names[model], mean_auc, std_auc), lw=lw, alpha=alpha)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            axi.fill_between(mean_fpr, tprs_lower, tprs_upper, color=colors[model], alpha=0.1)

        # Axis cosmetics do not depend on the model, so apply them once per figure.
        if use_zoom:
            # The joint curve is always computed on flattened scores, so its
            # window is the 'micro' one whatever label `avg` carries.
            (x_lo, x_hi), (y_lo, y_hi) = zoom_params.get(i, zoom_params['micro'])
            axi.set_xlim([x_lo-0.005, x_hi+0.005])
            axi.set_ylim([y_lo-0.005, y_hi+0.005])
            axi.set_xticks(np.arange(x_lo, x_hi+0.001, 0.05))
            axi.set_yticks(np.arange(y_lo, y_hi+0.001, 0.05))
            axi.tick_params(axis='both', which='major', labelsize=26)
        else:
            axi.set_xlim([-0.05, 1.05])
            axi.set_ylim([-0.05, 1.05])
            axi.set_xticks(np.arange(0, 1.001, 0.2))
            axi.set_yticks(np.arange(0, 1.001, 0.2))
            axi.legend(loc="lower right", prop={'size': 20})
            axi.tick_params(axis='both', which='major', labelsize=30)

    zoom = '_zoom' if use_zoom else ''
    for cls, fig in zip(classes, figures):
        fig.savefig(ckpt_name+'/AUC_%s%s.png' % (cls, zoom))