|
a |
|
b/utils.py |
|
|
1 |
# Base / Native |
|
|
2 |
import math |
|
|
3 |
import os |
|
|
4 |
import pickle |
|
|
5 |
import re |
|
|
6 |
import warnings |
|
|
7 |
warnings.filterwarnings('ignore') |
|
|
8 |
|
|
|
9 |
# Numerical / Array |
|
|
10 |
import lifelines |
|
|
11 |
from lifelines.utils import concordance_index |
|
|
12 |
from lifelines import CoxPHFitter |
|
|
13 |
from lifelines.datasets import load_regression_dataset |
|
|
14 |
from lifelines.utils import k_fold_cross_validation |
|
|
15 |
from lifelines.statistics import logrank_test |
|
|
16 |
from imblearn.over_sampling import RandomOverSampler |
|
|
17 |
import matplotlib as mpl |
|
|
18 |
import matplotlib.pyplot as plt |
|
|
19 |
import matplotlib.font_manager as font_manager |
|
|
20 |
import numpy as np |
|
|
21 |
import pandas as pd |
|
|
22 |
from PIL import Image |
|
|
23 |
import pylab |
|
|
24 |
import scipy |
|
|
25 |
import seaborn as sns |
|
|
26 |
from sklearn import preprocessing |
|
|
27 |
from sklearn.model_selection import train_test_split, KFold |
|
|
28 |
from sklearn.metrics import average_precision_score, auc, f1_score, roc_curve, roc_auc_score |
|
|
29 |
from sklearn.preprocessing import LabelBinarizer |
|
|
30 |
|
|
|
31 |
from scipy import interp |
|
|
32 |
mpl.rcParams['axes.linewidth'] = 3 #set the value globally |
|
|
33 |
|
|
|
34 |
# Torch |
|
|
35 |
import torch |
|
|
36 |
import torch.nn as nn |
|
|
37 |
from torch.nn import init, Parameter |
|
|
38 |
from torch.utils.data._utils.collate import * |
|
|
39 |
from torch.utils.data.dataloader import default_collate |
|
|
40 |
import torch_geometric |
|
|
41 |
from torch_geometric.data import Batch |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
|
|
|
45 |
################ |
|
|
46 |
# Regularization |
|
|
47 |
################ |
|
|
48 |
def regularize_weights(model, reg_type=None): |
|
|
49 |
l1_reg = None |
|
|
50 |
|
|
|
51 |
for W in model.parameters(): |
|
|
52 |
if l1_reg is None: |
|
|
53 |
l1_reg = torch.abs(W).sum() |
|
|
54 |
else: |
|
|
55 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
56 |
return l1_reg |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
def regularize_path_weights(model, reg_type=None): |
|
|
60 |
l1_reg = None |
|
|
61 |
|
|
|
62 |
for W in model.module.classifier.parameters(): |
|
|
63 |
if l1_reg is None: |
|
|
64 |
l1_reg = torch.abs(W).sum() |
|
|
65 |
else: |
|
|
66 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
67 |
|
|
|
68 |
for W in model.module.linear.parameters(): |
|
|
69 |
if l1_reg is None: |
|
|
70 |
l1_reg = torch.abs(W).sum() |
|
|
71 |
else: |
|
|
72 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
73 |
|
|
|
74 |
return l1_reg |
|
|
75 |
|
|
|
76 |
|
|
|
77 |
def regularize_MM_weights(model, reg_type=None): |
|
|
78 |
l1_reg = None |
|
|
79 |
|
|
|
80 |
if model.module.__hasattr__('omic_net'): |
|
|
81 |
for W in model.module.omic_net.parameters(): |
|
|
82 |
if l1_reg is None: |
|
|
83 |
l1_reg = torch.abs(W).sum() |
|
|
84 |
else: |
|
|
85 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
86 |
|
|
|
87 |
if model.module.__hasattr__('linear_h_path'): |
|
|
88 |
for W in model.module.linear_h_path.parameters(): |
|
|
89 |
if l1_reg is None: |
|
|
90 |
l1_reg = torch.abs(W).sum() |
|
|
91 |
else: |
|
|
92 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
93 |
|
|
|
94 |
if model.module.__hasattr__('linear_h_omic'): |
|
|
95 |
for W in model.module.linear_h_omic.parameters(): |
|
|
96 |
if l1_reg is None: |
|
|
97 |
l1_reg = torch.abs(W).sum() |
|
|
98 |
else: |
|
|
99 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
100 |
|
|
|
101 |
if model.module.__hasattr__('linear_h_grph'): |
|
|
102 |
for W in model.module.linear_h_grph.parameters(): |
|
|
103 |
if l1_reg is None: |
|
|
104 |
l1_reg = torch.abs(W).sum() |
|
|
105 |
else: |
|
|
106 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
107 |
|
|
|
108 |
if model.module.__hasattr__('linear_z_path'): |
|
|
109 |
for W in model.module.linear_z_path.parameters(): |
|
|
110 |
if l1_reg is None: |
|
|
111 |
l1_reg = torch.abs(W).sum() |
|
|
112 |
else: |
|
|
113 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
114 |
|
|
|
115 |
if model.module.__hasattr__('linear_z_omic'): |
|
|
116 |
for W in model.module.linear_z_omic.parameters(): |
|
|
117 |
if l1_reg is None: |
|
|
118 |
l1_reg = torch.abs(W).sum() |
|
|
119 |
else: |
|
|
120 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
121 |
|
|
|
122 |
if model.module.__hasattr__('linear_z_grph'): |
|
|
123 |
for W in model.module.linear_z_grph.parameters(): |
|
|
124 |
if l1_reg is None: |
|
|
125 |
l1_reg = torch.abs(W).sum() |
|
|
126 |
else: |
|
|
127 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
128 |
|
|
|
129 |
if model.module.__hasattr__('linear_o_path'): |
|
|
130 |
for W in model.module.linear_o_path.parameters(): |
|
|
131 |
if l1_reg is None: |
|
|
132 |
l1_reg = torch.abs(W).sum() |
|
|
133 |
else: |
|
|
134 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
135 |
|
|
|
136 |
if model.module.__hasattr__('linear_o_omic'): |
|
|
137 |
for W in model.module.linear_o_omic.parameters(): |
|
|
138 |
if l1_reg is None: |
|
|
139 |
l1_reg = torch.abs(W).sum() |
|
|
140 |
else: |
|
|
141 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
142 |
|
|
|
143 |
if model.module.__hasattr__('linear_o_grph'): |
|
|
144 |
for W in model.module.linear_o_grph.parameters(): |
|
|
145 |
if l1_reg is None: |
|
|
146 |
l1_reg = torch.abs(W).sum() |
|
|
147 |
else: |
|
|
148 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
149 |
|
|
|
150 |
if model.module.__hasattr__('encoder1'): |
|
|
151 |
for W in model.module.encoder1.parameters(): |
|
|
152 |
if l1_reg is None: |
|
|
153 |
l1_reg = torch.abs(W).sum() |
|
|
154 |
else: |
|
|
155 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
156 |
|
|
|
157 |
if model.module.__hasattr__('encoder2'): |
|
|
158 |
for W in model.module.encoder2.parameters(): |
|
|
159 |
if l1_reg is None: |
|
|
160 |
l1_reg = torch.abs(W).sum() |
|
|
161 |
else: |
|
|
162 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
163 |
|
|
|
164 |
if model.module.__hasattr__('classifier'): |
|
|
165 |
for W in model.module.classifier.parameters(): |
|
|
166 |
if l1_reg is None: |
|
|
167 |
l1_reg = torch.abs(W).sum() |
|
|
168 |
else: |
|
|
169 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
170 |
|
|
|
171 |
return l1_reg |
|
|
172 |
|
|
|
173 |
|
|
|
174 |
def regularize_MM_omic(model, reg_type=None): |
|
|
175 |
l1_reg = None |
|
|
176 |
|
|
|
177 |
if model.module.__hasattr__('omic_net'): |
|
|
178 |
for W in model.module.omic_net.parameters(): |
|
|
179 |
if l1_reg is None: |
|
|
180 |
l1_reg = torch.abs(W).sum() |
|
|
181 |
else: |
|
|
182 |
l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1) |
|
|
183 |
|
|
|
184 |
return l1_reg |
|
|
185 |
|
|
|
186 |
|
|
|
187 |
|
|
|
188 |
################ |
|
|
189 |
# Network Initialization |
|
|
190 |
################ |
|
|
191 |
def init_weights(net, init_type='orthogonal', init_gain=0.02): |
|
|
192 |
"""Initialize network weights. |
|
|
193 |
|
|
|
194 |
Parameters: |
|
|
195 |
net (network) -- network to be initialized |
|
|
196 |
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal |
|
|
197 |
init_gain (float) -- scaling factor for normal, xavier and orthogonal. |
|
|
198 |
|
|
|
199 |
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might |
|
|
200 |
work better for some applications. Feel free to try yourself. |
|
|
201 |
""" |
|
|
202 |
def init_func(m): # define the initialization function |
|
|
203 |
classname = m.__class__.__name__ |
|
|
204 |
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): |
|
|
205 |
if init_type == 'normal': |
|
|
206 |
init.normal_(m.weight.data, 0.0, init_gain) |
|
|
207 |
elif init_type == 'xavier': |
|
|
208 |
init.xavier_normal_(m.weight.data, gain=init_gain) |
|
|
209 |
elif init_type == 'kaiming': |
|
|
210 |
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') |
|
|
211 |
elif init_type == 'orthogonal': |
|
|
212 |
init.orthogonal_(m.weight.data, gain=init_gain) |
|
|
213 |
else: |
|
|
214 |
raise NotImplementedError('initialization method [%s] is not implemented' % init_type) |
|
|
215 |
if hasattr(m, 'bias') and m.bias is not None: |
|
|
216 |
init.constant_(m.bias.data, 0.0) |
|
|
217 |
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. |
|
|
218 |
init.normal_(m.weight.data, 1.0, init_gain) |
|
|
219 |
init.constant_(m.bias.data, 0.0) |
|
|
220 |
|
|
|
221 |
print('initialize network with %s' % init_type) |
|
|
222 |
net.apply(init_func) # apply the initialization function <init_func> |
|
|
223 |
|
|
|
224 |
|
|
|
225 |
def init_max_weights(module): |
|
|
226 |
for m in module.modules(): |
|
|
227 |
if type(m) == nn.Linear: |
|
|
228 |
stdv = 1. / math.sqrt(m.weight.size(1)) |
|
|
229 |
m.weight.data.normal_(0, stdv) |
|
|
230 |
m.bias.data.zero_() |
|
|
231 |
|
|
|
232 |
|
|
|
233 |
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): |
|
|
234 |
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights |
|
|
235 |
Parameters: |
|
|
236 |
net (network) -- the network to be initialized |
|
|
237 |
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal |
|
|
238 |
gain (float) -- scaling factor for normal, xavier and orthogonal. |
|
|
239 |
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 |
|
|
240 |
|
|
|
241 |
Return an initialized network. |
|
|
242 |
""" |
|
|
243 |
if len(gpu_ids) > 0: |
|
|
244 |
assert(torch.cuda.is_available()) |
|
|
245 |
net.to(gpu_ids[0]) |
|
|
246 |
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs |
|
|
247 |
|
|
|
248 |
if init_type != 'max' and init_type != 'none': |
|
|
249 |
print("Init Type:", init_type) |
|
|
250 |
init_weights(net, init_type, init_gain=init_gain) |
|
|
251 |
elif init_type == 'none': |
|
|
252 |
print("Init Type: Not initializing networks.") |
|
|
253 |
elif init_type == 'max': |
|
|
254 |
print("Init Type: Self-Normalizing Weights") |
|
|
255 |
return net |
|
|
256 |
|
|
|
257 |
|
|
|
258 |
|
|
|
259 |
################ |
|
|
260 |
# Freeze / Unfreeze |
|
|
261 |
################ |
|
|
262 |
def unfreeze_unimodal(opt, model, epoch): |
|
|
263 |
if opt.mode == 'graphomic': |
|
|
264 |
if epoch == 5: |
|
|
265 |
dfs_unfreeze(model.module.omic_net) |
|
|
266 |
print("Unfreezing Omic") |
|
|
267 |
if epoch == 5: |
|
|
268 |
dfs_unfreeze(model.module.grph_net) |
|
|
269 |
print("Unfreezing Graph") |
|
|
270 |
elif opt.mode == 'pathomic': |
|
|
271 |
if epoch == 5: |
|
|
272 |
dfs_unfreeze(model.module.omic_net) |
|
|
273 |
print("Unfreezing Omic") |
|
|
274 |
elif opt.mode == 'pathgraph': |
|
|
275 |
if epoch == 5: |
|
|
276 |
dfs_unfreeze(model.module.grph_net) |
|
|
277 |
print("Unfreezing Graph") |
|
|
278 |
elif opt.mode == "pathgraphomic": |
|
|
279 |
if epoch == 5: |
|
|
280 |
dfs_unfreeze(model.module.omic_net) |
|
|
281 |
print("Unfreezing Omic") |
|
|
282 |
if epoch == 5: |
|
|
283 |
dfs_unfreeze(model.module.grph_net) |
|
|
284 |
print("Unfreezing Graph") |
|
|
285 |
elif opt.mode == "omicomic": |
|
|
286 |
if epoch == 5: |
|
|
287 |
dfs_unfreeze(model.module.omic_net) |
|
|
288 |
print("Unfreezing Omic") |
|
|
289 |
elif opt.mode == "graphgraph": |
|
|
290 |
if epoch == 5: |
|
|
291 |
dfs_unfreeze(model.module.grph_net) |
|
|
292 |
print("Unfreezing Graph") |
|
|
293 |
|
|
|
294 |
|
|
|
295 |
def dfs_freeze(model): |
|
|
296 |
for name, child in model.named_children(): |
|
|
297 |
for param in child.parameters(): |
|
|
298 |
param.requires_grad = False |
|
|
299 |
dfs_freeze(child) |
|
|
300 |
|
|
|
301 |
|
|
|
302 |
def dfs_unfreeze(model): |
|
|
303 |
for name, child in model.named_children(): |
|
|
304 |
for param in child.parameters(): |
|
|
305 |
param.requires_grad = True |
|
|
306 |
dfs_unfreeze(child) |
|
|
307 |
|
|
|
308 |
|
|
|
309 |
def print_if_frozen(module): |
|
|
310 |
for idx, child in enumerate(module.children()): |
|
|
311 |
for param in child.parameters(): |
|
|
312 |
if param.requires_grad == True: |
|
|
313 |
print("Learnable!!! %d:" % idx, child) |
|
|
314 |
else: |
|
|
315 |
print("Still Frozen %d:" % idx, child) |
|
|
316 |
|
|
|
317 |
|
|
|
318 |
def unfreeze_vgg_features(model, epoch): |
|
|
319 |
epoch_schedule = {30:45} |
|
|
320 |
unfreeze_index = epoch_schedule[epoch] |
|
|
321 |
for idx, child in enumerate(model.features.children()): |
|
|
322 |
if idx > unfreeze_index: |
|
|
323 |
print("Unfreezing %d:" %idx, child) |
|
|
324 |
for param in child.parameters(): |
|
|
325 |
param.requires_grad = True |
|
|
326 |
else: |
|
|
327 |
print("Still Frozen %d:" %idx, child) |
|
|
328 |
continue |
|
|
329 |
|
|
|
330 |
|
|
|
331 |
|
|
|
332 |
################ |
|
|
333 |
# Collate Utils |
|
|
334 |
################ |
|
|
335 |
def mixed_collate(batch): |
|
|
336 |
elem = batch[0] |
|
|
337 |
elem_type = type(elem) |
|
|
338 |
transposed = zip(*batch) |
|
|
339 |
return [Batch.from_data_list(samples, []) if type(samples[0]) is torch_geometric.data.data.Data else default_collate(samples) for samples in transposed] |
|
|
340 |
|
|
|
341 |
|
|
|
342 |
|
|
|
343 |
################ |
|
|
344 |
# Survival Utils |
|
|
345 |
################ |
|
|
346 |
def CoxLoss(survtime, censor, hazard_pred, device): |
|
|
347 |
# This calculation credit to Travers Ching https://github.com/traversc/cox-nnet |
|
|
348 |
# Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data |
|
|
349 |
current_batch_len = len(survtime) |
|
|
350 |
R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int) |
|
|
351 |
for i in range(current_batch_len): |
|
|
352 |
for j in range(current_batch_len): |
|
|
353 |
R_mat[i,j] = survtime[j] >= survtime[i] |
|
|
354 |
|
|
|
355 |
R_mat = torch.FloatTensor(R_mat).to(device) |
|
|
356 |
theta = hazard_pred.reshape(-1) |
|
|
357 |
exp_theta = torch.exp(theta) |
|
|
358 |
loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * censor) |
|
|
359 |
return loss_cox |
|
|
360 |
|
|
|
361 |
|
|
|
362 |
def accuracy(output, labels): |
|
|
363 |
preds = output.max(1)[1].type_as(labels) |
|
|
364 |
correct = preds.eq(labels).double() |
|
|
365 |
correct = correct.sum() |
|
|
366 |
return correct / len(labels) |
|
|
367 |
|
|
|
368 |
|
|
|
369 |
def accuracy_cox(hazardsdata, labels): |
|
|
370 |
# This accuracy is based on estimated survival events against true survival events |
|
|
371 |
median = np.median(hazardsdata) |
|
|
372 |
hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int) |
|
|
373 |
hazards_dichotomize[hazardsdata > median] = 1 |
|
|
374 |
correct = np.sum(hazards_dichotomize == labels) |
|
|
375 |
return correct / len(labels) |
|
|
376 |
|
|
|
377 |
|
|
|
378 |
def cox_log_rank(hazardsdata, labels, survtime_all): |
|
|
379 |
median = np.median(hazardsdata) |
|
|
380 |
hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int) |
|
|
381 |
hazards_dichotomize[hazardsdata > median] = 1 |
|
|
382 |
idx = hazards_dichotomize == 0 |
|
|
383 |
T1 = survtime_all[idx] |
|
|
384 |
T2 = survtime_all[~idx] |
|
|
385 |
E1 = labels[idx] |
|
|
386 |
E2 = labels[~idx] |
|
|
387 |
results = logrank_test(T1, T2, event_observed_A=E1, event_observed_B=E2) |
|
|
388 |
pvalue_pred = results.p_value |
|
|
389 |
return(pvalue_pred) |
|
|
390 |
|
|
|
391 |
|
|
|
392 |
def CIndex(hazards, labels, survtime_all): |
|
|
393 |
concord = 0. |
|
|
394 |
total = 0. |
|
|
395 |
N_test = labels.shape[0] |
|
|
396 |
for i in range(N_test): |
|
|
397 |
if labels[i] == 1: |
|
|
398 |
for j in range(N_test): |
|
|
399 |
if survtime_all[j] > survtime_all[i]: |
|
|
400 |
total += 1 |
|
|
401 |
if hazards[j] < hazards[i]: concord += 1 |
|
|
402 |
elif hazards[j] < hazards[i]: concord += 0.5 |
|
|
403 |
|
|
|
404 |
return(concord/total) |
|
|
405 |
|
|
|
406 |
|
|
|
407 |
def CIndex_lifeline(hazards, labels, survtime_all): |
|
|
408 |
return(concordance_index(survtime_all, -hazards, labels)) |
|
|
409 |
|
|
|
410 |
|
|
|
411 |
|
|
|
412 |
################ |
|
|
413 |
# Data Utils |
|
|
414 |
################ |
|
|
415 |
def addHistomolecularSubtype(data): |
|
|
416 |
""" |
|
|
417 |
Molecular Subtype: IDHwt == 0, IDHmut-non-codel == 1, IDHmut-codel == 2 |
|
|
418 |
Histology Subtype: astrocytoma == 0, oligoastrocytoma == 1, oligodendroglioma == 2, glioblastoma == 3 |
|
|
419 |
""" |
|
|
420 |
subtyped_data = data.copy() |
|
|
421 |
subtyped_data.insert(loc=0, column='Histomolecular subtype', value=np.ones(len(data))) |
|
|
422 |
idhwt_ATC = np.logical_and(data['Molecular subtype'] == 0, np.logical_or(data['Histology'] == 0, data['Histology'] == 3)) |
|
|
423 |
subtyped_data.loc[idhwt_ATC, 'Histomolecular subtype'] = 'idhwt_ATC' |
|
|
424 |
|
|
|
425 |
idhmut_ATC = np.logical_and(data['Molecular subtype'] == 1, np.logical_or(data['Histology'] == 0, data['Histology'] == 3)) |
|
|
426 |
subtyped_data.loc[idhmut_ATC, 'Histomolecular subtype'] = 'idhmut_ATC' |
|
|
427 |
|
|
|
428 |
ODG = np.logical_and(data['Molecular subtype'] == 2, data['Histology'] == 2) |
|
|
429 |
subtyped_data.loc[ODG, 'Histomolecular subtype'] = 'ODG' |
|
|
430 |
return subtyped_data |
|
|
431 |
|
|
|
432 |
|
|
|
433 |
def changeHistomolecularSubtype(data): |
|
|
434 |
""" |
|
|
435 |
Molecular Subtype: IDHwt == 0, IDHmut-non-codel == 1, IDHmut-codel == 2 |
|
|
436 |
Histology Subtype: astrocytoma == 0, oligoastrocytoma == 1, oligodendroglioma == 2, glioblastoma == 3 |
|
|
437 |
""" |
|
|
438 |
data = data.drop(['Histomolecular subtype'], axis=1) |
|
|
439 |
subtyped_data = data.copy() |
|
|
440 |
subtyped_data.insert(loc=0, column='Histomolecular subtype', value=np.ones(len(data))) |
|
|
441 |
idhwt_ATC = np.logical_and(data['Molecular subtype'] == 0, np.logical_or(data['Histology'] == 0, data['Histology'] == 3)) |
|
|
442 |
subtyped_data.loc[idhwt_ATC, 'Histomolecular subtype'] = 'idhwt_ATC' |
|
|
443 |
|
|
|
444 |
idhmut_ATC = np.logical_and(data['Molecular subtype'] == 1, np.logical_or(data['Histology'] == 0, data['Histology'] == 3)) |
|
|
445 |
subtyped_data.loc[idhmut_ATC, 'Histomolecular subtype'] = 'idhmut_ATC' |
|
|
446 |
|
|
|
447 |
ODG = np.logical_and(data['Molecular subtype'] == 2, data['Histology'] == 2) |
|
|
448 |
subtyped_data.loc[ODG, 'Histomolecular subtype'] = 'ODG' |
|
|
449 |
return subtyped_data |
|
|
450 |
|
|
|
451 |
|
|
|
452 |
def getCleanAllDataset(dataroot='./data/TCGA_GBMLGG/', ignore_missing_moltype=False, ignore_missing_histype=False, use_rnaseq=False): |
|
|
453 |
### 1. Joining all_datasets.csv with grade data. Looks at columns with misisng samples |
|
|
454 |
metadata = ['Histology', 'Grade', 'Molecular subtype', 'TCGA ID', 'censored', 'Survival months'] |
|
|
455 |
all_dataset = pd.read_csv(os.path.join(dataroot, 'all_dataset.csv')).drop('indexes', axis=1) |
|
|
456 |
all_dataset.index = all_dataset['TCGA ID'] |
|
|
457 |
|
|
|
458 |
all_grade = pd.read_csv(os.path.join(dataroot, 'grade_data.csv')) |
|
|
459 |
all_grade['Histology'] = all_grade['Histology'].str.replace('astrocytoma (glioblastoma)', 'glioblastoma', regex=False) |
|
|
460 |
all_grade.index = all_grade['TCGA ID'] |
|
|
461 |
assert pd.Series(all_dataset.index).equals(pd.Series(sorted(all_grade.index))) |
|
|
462 |
|
|
|
463 |
all_dataset = all_dataset.join(all_grade[['Histology', 'Grade', 'Molecular subtype']], how='inner') |
|
|
464 |
cols = all_dataset.columns.tolist() |
|
|
465 |
cols = cols[-3:] + cols[:-3] |
|
|
466 |
all_dataset = all_dataset[cols] |
|
|
467 |
|
|
|
468 |
if use_rnaseq: |
|
|
469 |
gbm = pd.read_csv(os.path.join(dataroot, 'mRNA_Expression_z-Scores_RNA_Seq_RSEM.txt'), sep='\t', skiprows=1, index_col=0) |
|
|
470 |
lgg = pd.read_csv(os.path.join(dataroot, 'mRNA_Expression_Zscores_RSEM.txt'), sep='\t', skiprows=1, index_col=0) |
|
|
471 |
gbm = gbm[gbm.columns[~gbm.isnull().all()]] |
|
|
472 |
lgg = lgg[lgg.columns[~lgg.isnull().all()]] |
|
|
473 |
glioma_RNAseq = gbm.join(lgg, how='inner').T |
|
|
474 |
glioma_RNAseq = glioma_RNAseq.dropna(axis=1) |
|
|
475 |
glioma_RNAseq.columns = [gene+'_rnaseq' for gene in glioma_RNAseq.columns] |
|
|
476 |
glioma_RNAseq.index = [patname[:12] for patname in glioma_RNAseq.index] |
|
|
477 |
glioma_RNAseq = glioma_RNAseq.iloc[~glioma_RNAseq.index.duplicated()] |
|
|
478 |
glioma_RNAseq.index.name = 'TCGA ID' |
|
|
479 |
all_dataset = all_dataset.join(glioma_RNAseq, how='inner') |
|
|
480 |
|
|
|
481 |
pat_missing_moltype = all_dataset[all_dataset['Molecular subtype'].isna()].index |
|
|
482 |
pat_missing_idh = all_dataset[all_dataset['idh mutation'].isna()].index |
|
|
483 |
pat_missing_1p19q = all_dataset[all_dataset['codeletion'].isna()].index |
|
|
484 |
print("# Missing Molecular Subtype:", len(pat_missing_moltype)) |
|
|
485 |
print("# Missing IDH Mutation:", len(pat_missing_idh)) |
|
|
486 |
print("# Missing 1p19q Codeletion:", len(pat_missing_1p19q)) |
|
|
487 |
assert pat_missing_moltype.equals(pat_missing_idh) |
|
|
488 |
assert pat_missing_moltype.equals(pat_missing_1p19q) |
|
|
489 |
pat_missing_grade = all_dataset[all_dataset['Grade'].isna()].index |
|
|
490 |
pat_missing_histype = all_dataset[all_dataset['Histology'].isna()].index |
|
|
491 |
print("# Missing Histological Subtype:", len(pat_missing_histype)) |
|
|
492 |
print("# Missing Grade:", len(pat_missing_grade)) |
|
|
493 |
assert pat_missing_histype.equals(pat_missing_grade) |
|
|
494 |
|
|
|
495 |
### 2. Impute Missing Genomic Data: Removes patients with missing molecular subtype / idh mutation / 1p19q. Else imputes with median value of each column. Fills missing Molecular subtype with "Missing" |
|
|
496 |
if ignore_missing_moltype: |
|
|
497 |
all_dataset = all_dataset[all_dataset['Molecular subtype'].isna() == False] |
|
|
498 |
for col in all_dataset.drop(metadata, axis=1).columns: |
|
|
499 |
all_dataset['Molecular subtype'] = all_dataset['Molecular subtype'].fillna('Missing') |
|
|
500 |
all_dataset[col] = all_dataset[col].fillna(all_dataset[col].median()) |
|
|
501 |
|
|
|
502 |
### 3. Impute Missing Histological Data: Removes patients with missing histological subtype / grade. Else imputes with "missing" / grade -1 |
|
|
503 |
if ignore_missing_histype: |
|
|
504 |
all_dataset = all_dataset[all_dataset['Histology'].isna() == False] |
|
|
505 |
else: |
|
|
506 |
all_dataset['Grade'] = all_dataset['Grade'].fillna(1) |
|
|
507 |
all_dataset['Histology'] = all_dataset['Histology'].fillna('Missing') |
|
|
508 |
all_dataset['Grade'] = all_dataset['Grade'] - 2 |
|
|
509 |
|
|
|
510 |
### 4. Adds Histomolecular subtype |
|
|
511 |
ms2int = {'Missing':-1, 'IDHwt':0, 'IDHmut-non-codel':1, 'IDHmut-codel':2} |
|
|
512 |
all_dataset[['Molecular subtype']] = all_dataset[['Molecular subtype']].applymap(lambda s: ms2int.get(s) if s in ms2int else s) |
|
|
513 |
hs2int = {'Missing':-1, 'astrocytoma':0, 'oligoastrocytoma':1, 'oligodendroglioma':2, 'glioblastoma':3} |
|
|
514 |
all_dataset[['Histology']] = all_dataset[['Histology']].applymap(lambda s: hs2int.get(s) if s in hs2int else s) |
|
|
515 |
all_dataset = addHistomolecularSubtype(all_dataset) |
|
|
516 |
metadata.extend(['Histomolecular subtype']) |
|
|
517 |
all_dataset['censored'] = 1 - all_dataset['censored'] |
|
|
518 |
return metadata, all_dataset |
|
|
519 |
|
|
|
520 |
|
|
|
521 |
|
|
|
522 |
################ |
|
|
523 |
# Analysis Utils |
|
|
524 |
################ |
|
|
525 |
def count_parameters(model): |
|
|
526 |
return sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
527 |
|
|
|
528 |
|
|
|
529 |
def hazard2grade(hazard, p): |
|
|
530 |
if hazard < p[0]: |
|
|
531 |
return 0 |
|
|
532 |
elif hazard < p[1]: |
|
|
533 |
return 1 |
|
|
534 |
return 2 |
|
|
535 |
|
|
|
536 |
|
|
|
537 |
def p(n): |
|
|
538 |
def percentile_(x): |
|
|
539 |
return np.percentile(x, n) |
|
|
540 |
percentile_.__name__ = 'p%s' % n |
|
|
541 |
return percentile_ |
|
|
542 |
|
|
|
543 |
|
|
|
544 |
def natural_sort(l): |
|
|
545 |
convert = lambda text: int(text) if text.isdigit() else text.lower() |
|
|
546 |
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] |
|
|
547 |
return sorted(l, key = alphanum_key) |
|
|
548 |
|
|
|
549 |
|
|
|
550 |
def CI_pm(data, confidence=0.95): |
|
|
551 |
a = 1.0 * np.array(data) |
|
|
552 |
n = len(a) |
|
|
553 |
m, se = np.mean(a), scipy.stats.sem(a) |
|
|
554 |
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1) |
|
|
555 |
return str("{0:.4f} ± ".format(m) + "{0:.3f}".format(h)) |
|
|
556 |
|
|
|
557 |
|
|
|
558 |
def CI_interval(data, confidence=0.95): |
|
|
559 |
a = 1.0 * np.array(data) |
|
|
560 |
n = len(a) |
|
|
561 |
m, se = np.mean(a), scipy.stats.sem(a) |
|
|
562 |
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1) |
|
|
563 |
return str("{0:.3f}, ".format(m-h) + "{0:.3f}".format(m+h)) |
|
|
564 |
|
|
|
565 |
|
|
|
566 |
def poolSurvTestPD(ckpt_name='./checkpoints/TCGA_GBMLGG/surv_15_rnaseq/', model='pathgraphomic_fusion', split='test', zscore=False, agg_type='Hazard_mean'): |
|
|
567 |
all_dataset_regstrd_pooled = [] |
|
|
568 |
ignore_missing_moltype = 1 if 'omic' in model else 0 |
|
|
569 |
ignore_missing_histype = 1 if 'grad' in ckpt_name else 0 |
|
|
570 |
use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if ((('path' in model) or ('graph' in model)) and ('cox' not in model)) else ('_', 'all_st', 0) |
|
|
571 |
use_rnaseq = '_rnaseq' if ('rnaseq' in ckpt_name and 'path' != model and 'pathpath' not in model and 'graph' != model and 'graphgraph' not in model) else '' |
|
|
572 |
|
|
|
573 |
for k in range(1,16): |
|
|
574 |
pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb')) |
|
|
575 |
|
|
|
576 |
if 'cox' not in model: |
|
|
577 |
surv_all = pd.DataFrame(np.stack(np.delete(np.array(pred), 3))).T |
|
|
578 |
surv_all.columns = ['Hazard', 'Survival months', 'censored', 'Grade'] |
|
|
579 |
data_cv = pickle.load(open('./data/TCGA_GBMLGG/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (roi_dir, ignore_missing_moltype, ignore_missing_histype, use_vgg_features, use_rnaseq), 'rb')) |
|
|
580 |
data_cv_splits = data_cv['cv_splits'] |
|
|
581 |
data_cv_split_k = data_cv_splits[k] |
|
|
582 |
assert np.all(data_cv_split_k[split]['t'] == pred[1]) # Data is correctly registered |
|
|
583 |
all_dataset = data_cv['data_pd'].drop('TCGA ID', axis=1) |
|
|
584 |
all_dataset_regstrd = all_dataset.loc[data_cv_split_k[split]['x_patname']] # Subset of "all_datasets" (metadata) that is registered with "pred" (predictions) |
|
|
585 |
assert np.all(np.array(all_dataset_regstrd['Survival months']) == pred[1]) |
|
|
586 |
assert np.all(np.array(all_dataset_regstrd['censored']) == pred[2]) |
|
|
587 |
assert np.all(np.array(all_dataset_regstrd['Grade']) == pred[4]) |
|
|
588 |
all_dataset_regstrd.insert(loc=0, column='Hazard', value = np.array(surv_all['Hazard'])) |
|
|
589 |
all_dataset_regstrd.index.name = 'TCGA ID' |
|
|
590 |
hazard_agg = all_dataset_regstrd.groupby('TCGA ID').agg({'Hazard': ['mean', 'median', max, p(0.25), p(0.75)]}) |
|
|
591 |
hazard_agg.columns = ["_".join(x) for x in hazard_agg.columns.ravel()] |
|
|
592 |
hazard_agg = hazard_agg[[agg_type]] |
|
|
593 |
hazard_agg.columns = ['Hazard'] |
|
|
594 |
pred = hazard_agg.join(all_dataset, how='inner') |
|
|
595 |
|
|
|
596 |
if zscore: pred['Hazard'] = scipy.stats.zscore(np.array(pred['Hazard'])) |
|
|
597 |
all_dataset_regstrd_pooled.append(pred) |
|
|
598 |
|
|
|
599 |
all_dataset_regstrd_pooled = pd.concat(all_dataset_regstrd_pooled) |
|
|
600 |
all_dataset_regstrd_pooled = changeHistomolecularSubtype(all_dataset_regstrd_pooled) |
|
|
601 |
return all_dataset_regstrd_pooled |
|
|
602 |
|
|
|
603 |
|
|
|
604 |
def getAggHazardCV(ckpt_name='./checkpoints/TCGA_GBMLGG/surv_15_rnaseq/', model='pathgraphomic_fusion', split='test', agg_type='Hazard_mean'): |
|
|
605 |
result = [] |
|
|
606 |
|
|
|
607 |
ignore_missing_moltype = 1 if 'omic' in model else 0 |
|
|
608 |
ignore_missing_histype = 1 if 'grad' in ckpt_name else 0 |
|
|
609 |
use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if (('path' in model) or ('graph' in model)) else ('_', 'all_st', 0) |
|
|
610 |
use_rnaseq = '_rnaseq' if ('rnaseq' in ckpt_name and 'path' != model and 'pathpath' not in model and 'graph' != model and 'graphgraph' not in model) else '' |
|
|
611 |
|
|
|
612 |
for k in range(1,16): |
|
|
613 |
pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb')) |
|
|
614 |
surv_all = pd.DataFrame(np.stack(np.delete(np.array(pred), 3))).T |
|
|
615 |
surv_all.columns = ['Hazard', 'Survival months', 'censored', 'Grade'] |
|
|
616 |
data_cv = pickle.load(open('./data/TCGA_GBMLGG/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (roi_dir, ignore_missing_moltype, ignore_missing_histype, use_vgg_features, use_rnaseq), 'rb')) |
|
|
617 |
data_cv_splits = data_cv['cv_splits'] |
|
|
618 |
data_cv_split_k = data_cv_splits[k] |
|
|
619 |
assert np.all(data_cv_split_k[split]['t'] == pred[1]) # Data is correctly registered |
|
|
620 |
all_dataset = data_cv['data_pd'].drop('TCGA ID', axis=1) |
|
|
621 |
all_dataset_regstrd = all_dataset.loc[data_cv_split_k[split]['x_patname']] # Subset of "all_datasets" (metadata) that is registered with "pred" (predictions) |
|
|
622 |
assert np.all(np.array(all_dataset_regstrd['Survival months']) == pred[1]) |
|
|
623 |
assert np.all(np.array(all_dataset_regstrd['censored']) == pred[2]) |
|
|
624 |
assert np.all(np.array(all_dataset_regstrd['Grade']) == pred[4]) |
|
|
625 |
all_dataset_regstrd.insert(loc=0, column='Hazard', value = np.array(surv_all['Hazard'])) |
|
|
626 |
all_dataset_regstrd.index.name = 'TCGA ID' |
|
|
627 |
hazard_agg = all_dataset_regstrd.groupby('TCGA ID').agg({'Hazard': ['mean', max, p(0.75)]}) |
|
|
628 |
hazard_agg.columns = ["_".join(x) for x in hazard_agg.columns.ravel()] |
|
|
629 |
hazard_agg = hazard_agg[[agg_type]] |
|
|
630 |
hazard_agg.columns = ['Hazard'] |
|
|
631 |
all_dataset_hazard = hazard_agg.join(all_dataset, how='inner') |
|
|
632 |
cin = CIndex_lifeline(all_dataset_hazard['Hazard'], all_dataset_hazard['censored'], all_dataset_hazard['Survival months']) |
|
|
633 |
result.append(cin) |
|
|
634 |
|
|
|
635 |
return result |
|
|
636 |
|
|
|
637 |
|
|
|
638 |
def calcGradMetrics(ckpt_name='./checkpoints/grad_15/', model='pathgraphomic_fusion', split='test', avg='micro'): |
|
|
639 |
auc_all = [] |
|
|
640 |
ap_all = [] |
|
|
641 |
f1_all = [] |
|
|
642 |
f1_gradeIV_all = [] |
|
|
643 |
|
|
|
644 |
ignore_missing_moltype = 1 if 'omic' in model else 0 |
|
|
645 |
ignore_missing_histype = 1 if 'grad' in ckpt_name else 0 |
|
|
646 |
use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if (('path' in model) or ('graph' in model)) else ('_', 'all_st', 0) |
|
|
647 |
|
|
|
648 |
for k in range(1,16): |
|
|
649 |
pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb')) |
|
|
650 |
grade_pred, grade = np.array(pred[3]), np.array(pred[4]) |
|
|
651 |
enc = LabelBinarizer() |
|
|
652 |
enc.fit(grade) |
|
|
653 |
grade_oh = enc.transform(grade) |
|
|
654 |
rocauc = roc_auc_score(grade_oh, grade_pred, avg) |
|
|
655 |
ap = average_precision_score(grade_oh, grade_pred, average=avg) |
|
|
656 |
f1 = f1_score(grade_pred.argmax(axis=1), grade, average=avg) |
|
|
657 |
f1_gradeIV = f1_score(grade_pred.argmax(axis=1), grade, average=None)[2] |
|
|
658 |
|
|
|
659 |
auc_all.append(rocauc) |
|
|
660 |
ap_all.append(ap) |
|
|
661 |
f1_all.append(f1) |
|
|
662 |
f1_gradeIV_all.append(f1_gradeIV) |
|
|
663 |
|
|
|
664 |
return np.array([CI_pm(auc_all), CI_pm(ap_all), CI_pm(f1_all), CI_pm(f1_gradeIV_all)]) |
|
|
665 |
|
|
|
666 |
|
|
|
667 |
|
|
|
668 |
################ |
|
|
669 |
# Plot Utils |
|
|
670 |
################ |
|
|
671 |
def makeKaplanMeierPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='omic', split='test', zscore=False, agg_type='Hazard_mean'): |
|
|
672 |
def hazard2KMCurve(data, subtype): |
|
|
673 |
p = np.percentile(data['Hazard'], [33, 66]) |
|
|
674 |
if p[0] == p[1]: p[0] = 2.99997 |
|
|
675 |
data.insert(0, 'grade_pred', [hazard2grade(hazard, p) for hazard in data['Hazard']]) |
|
|
676 |
kmf_pred = lifelines.KaplanMeierFitter() |
|
|
677 |
kmf_gt = lifelines.KaplanMeierFitter() |
|
|
678 |
|
|
|
679 |
def get_name(model): |
|
|
680 |
mode2name = {'pathgraphomic':'Pathomic F.', 'pathomic':'Pathomic F.', 'graphomic':'Pathomic F.', 'path':'Histology CNN', 'graph':'Histology GCN', 'omic':'Genomic SNN'} |
|
|
681 |
for mode in mode2name.keys(): |
|
|
682 |
if mode in model: return mode2name[mode] |
|
|
683 |
return 'N/A' |
|
|
684 |
|
|
|
685 |
fig = plt.figure(figsize=(10, 10), dpi=600) |
|
|
686 |
ax = plt.subplot() |
|
|
687 |
censor_style = {'ms': 20, 'marker': '+'} |
|
|
688 |
|
|
|
689 |
temp = data[data['Grade']==0] |
|
|
690 |
kmf_gt.fit(temp['Survival months']/365, temp['censored'], label="Grade II") |
|
|
691 |
kmf_gt.plot(ax=ax, show_censors=True, ci_show=False, c='g', linewidth=3, ls='--', markerfacecolor='black', censor_styles=censor_style) |
|
|
692 |
temp = data[data['grade_pred']==0] |
|
|
693 |
kmf_pred.fit(temp['Survival months']/365, temp['censored'], label="%s (Low)" % get_name(model)) |
|
|
694 |
kmf_pred.plot(ax=ax, show_censors=True, ci_show=False, c='g', linewidth=4, ls='-', markerfacecolor='black', censor_styles=censor_style) |
|
|
695 |
|
|
|
696 |
temp = data[data['Grade']==1] |
|
|
697 |
kmf_gt.fit(temp['Survival months']/365, temp['censored'], label="Grade III") |
|
|
698 |
kmf_gt.plot(ax=ax, show_censors=True, ci_show=False, c='b', linewidth=3, ls='--', censor_styles=censor_style) |
|
|
699 |
temp = data[data['grade_pred']==1] |
|
|
700 |
kmf_pred.fit(temp['Survival months']/365, temp['censored'], label="%s (Mid)" % get_name(model)) |
|
|
701 |
kmf_pred.plot(ax=ax, show_censors=True, ci_show=False, c='b', linewidth=4, ls='-', censor_styles=censor_style) |
|
|
702 |
|
|
|
703 |
if subtype != 'ODG': |
|
|
704 |
temp = data[data['Grade']==2] |
|
|
705 |
kmf_gt.fit(temp['Survival months']/365, temp['censored'], label="Grade IV") |
|
|
706 |
kmf_gt.plot(ax=ax, show_censors=True, ci_show=False, c='r', linewidth=3, ls='--', censor_styles=censor_style) |
|
|
707 |
temp = data[data['grade_pred']==2] |
|
|
708 |
kmf_pred.fit(temp['Survival months']/365, temp['censored'], label="%s (High)" % get_name(model)) |
|
|
709 |
kmf_pred.plot(ax=ax, show_censors=True, ci_show=False, c='r', linewidth=4, ls='-', censor_styles=censor_style) |
|
|
710 |
|
|
|
711 |
ax.set_xlabel('') |
|
|
712 |
ax.set_ylim(0, 1) |
|
|
713 |
ax.set_yticks(np.arange(0, 1.001, 0.5)) |
|
|
714 |
|
|
|
715 |
ax.tick_params(axis='both', which='major', labelsize=40) |
|
|
716 |
plt.legend(fontsize=32, prop=font_manager.FontProperties(family='Arial', style='normal', size=32)) |
|
|
717 |
if subtype != 'idhwt_ATC': ax.get_legend().remove() |
|
|
718 |
return fig |
|
|
719 |
|
|
|
720 |
data = poolSurvTestPD(ckpt_name, model, split, zscore, agg_type) |
|
|
721 |
for subtype in ['idhwt_ATC', 'idhmut_ATC', 'ODG']: |
|
|
722 |
fig = hazard2KMCurve(data[data['Histomolecular subtype'] == subtype], subtype) |
|
|
723 |
fig.savefig(ckpt_name+'/%s_KM_%s.png' % (model, subtype)) |
|
|
724 |
|
|
|
725 |
fig = hazard2KMCurve(data, 'all') |
|
|
726 |
fig.savefig(ckpt_name+'/%s_KM_%s.png' % (model, 'all')) |
|
|
727 |
|
|
|
728 |
|
|
|
729 |
def makeHazardSwarmPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='path', split='test', zscore=True, agg_type='Hazard_mean'): |
|
|
730 |
mpl.rcParams['font.family'] = "arial" |
|
|
731 |
data = poolSurvTestPD(ckpt_name=ckpt_name, model=model, split=split, zscore=zscore, agg_type=agg_type) |
|
|
732 |
data = data[data['Grade'] != -1] |
|
|
733 |
data = data[data['Histomolecular subtype'] != -1] |
|
|
734 |
data['Grade'] = data['Grade'].astype(int).astype(str) |
|
|
735 |
data['Grade'] = data['Grade'].str.replace('0', 'Grade II', regex=False) |
|
|
736 |
data['Grade'] = data['Grade'].str.replace('1', 'Grade III', regex=False) |
|
|
737 |
data['Grade'] = data['Grade'].str.replace('2', 'Grade IV', regex=False) |
|
|
738 |
data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace('idhwt_ATC', 'IDH-wt \n astryocytoma', regex=False) |
|
|
739 |
data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace('idhmut_ATC', 'IDH-mut \n astrocytoma', regex=False) |
|
|
740 |
data['Histomolecular subtype'] = data['Histomolecular subtype'].str.replace('ODG', 'Oligodendroglioma', regex=False) |
|
|
741 |
|
|
|
742 |
fig, ax = plt.subplots(dpi=600) |
|
|
743 |
ax.set_ylim([-2, 2.5]) # plt.ylim(-2, 2) |
|
|
744 |
ax.spines['right'].set_visible(False) |
|
|
745 |
ax.spines['top'].set_visible(False) |
|
|
746 |
ax.set_yticks(np.arange(-2, 2.001, 1)) |
|
|
747 |
|
|
|
748 |
sns.swarmplot(x = 'Histomolecular subtype', y='Hazard', data=data, hue='Grade', |
|
|
749 |
palette={"Grade II":"#AFD275" , "Grade III":"#7395AE", "Grade IV":"#E7717D"}, |
|
|
750 |
size = 4, alpha = 0.9, ax=ax) |
|
|
751 |
|
|
|
752 |
ax.set_xlabel('') # ax.set_xlabel('Histomolecular subtype', size=16) |
|
|
753 |
ax.set_ylabel('') # ax.set_ylabel('Hazard (Z-Score)', size=16) |
|
|
754 |
ax.tick_params(axis='y', which='both', labelsize=20) |
|
|
755 |
ax.tick_params(axis='x', which='both', labelsize=15) |
|
|
756 |
ax.tick_params(axis='x', which='both', labelbottom='off') # doesn't work?? |
|
|
757 |
ax.legend(prop={'size': 8}) |
|
|
758 |
fig.savefig(ckpt_name+'/%s_HSP.png' % (model)) |
|
|
759 |
|
|
|
760 |
|
|
|
761 |
def makeHazardBoxPlot(ckpt_name='./checkpoints/surv_15_rnaseq/', model='omic', split='test', zscore=True, agg_type='Hazard_mean'): |
|
|
762 |
mpl.rcParams['font.family'] = "arial" |
|
|
763 |
data = poolSurvTestPD(ckpt_name, model, split, zscore, 'Hazard_mean') |
|
|
764 |
data['Grade'] = data['Grade'].astype(int).astype(str) |
|
|
765 |
data['Grade'] = data['Grade'].str.replace('0', 'II', regex=False) |
|
|
766 |
data['Grade'] = data['Grade'].str.replace('1', 'III', regex=False) |
|
|
767 |
data['Grade'] = data['Grade'].str.replace('2', 'IV', regex=False) |
|
|
768 |
|
|
|
769 |
fig, axes = plt.subplots(nrows=1, ncols=3, gridspec_kw={'width_ratios': [3, 3, 2]}, dpi=600) |
|
|
770 |
plt.subplots_adjust(wspace=0, hspace=0) |
|
|
771 |
plt.ylim(-2, 2) |
|
|
772 |
plt.yticks(np.arange(-2, 2.001, 1)) |
|
|
773 |
#color_dict = {0: '#CF9498', 1: '#8CC7C8', 2: '#AAA0C6'} |
|
|
774 |
#color_dict = {0: '#F76C6C', 1: '#A8D0E6', 2: '#F8E9A1'} |
|
|
775 |
color_dict = ['#F76C6C', '#A8D0E6', '#F8E9A1'] |
|
|
776 |
subtypes = ['idhwt_ATC', 'idhmut_ATC', 'ODG'] |
|
|
777 |
|
|
|
778 |
for i in range(len(subtypes)): |
|
|
779 |
axes[i].spines["top"].set_visible(False) |
|
|
780 |
axes[i].spines["right"].set_visible(False) |
|
|
781 |
axes[i].xaxis.grid(False) |
|
|
782 |
axes[i].yaxis.grid(False) |
|
|
783 |
|
|
|
784 |
if i > 0: |
|
|
785 |
axes[i].get_yaxis().set_visible(False) |
|
|
786 |
axes[i].spines["left"].set_visible(False) |
|
|
787 |
|
|
|
788 |
order = ["II","III","IV"] if subtypes[i] != 'ODG' else ["II", "III"] |
|
|
789 |
|
|
|
790 |
axes[i].xaxis.label.set_visible(False) |
|
|
791 |
axes[i].yaxis.label.set_visible(False) |
|
|
792 |
axes[i].tick_params(axis='y', which='both', labelsize=20) |
|
|
793 |
axes[i].tick_params(axis='x', which='both', labelsize=15) |
|
|
794 |
datapoints = data[data['Histomolecular subtype'] == subtypes[i]] |
|
|
795 |
sns.boxplot(y='Hazard', x="Grade", data=datapoints, ax = axes[i], color=color_dict[i], order=order) |
|
|
796 |
sns.stripplot(y='Hazard', x='Grade', data=datapoints, alpha=0.2, jitter=0.2, color='k', ax = axes[i], order=order) |
|
|
797 |
axes[i].set_ylim(-2.5, 2.5) |
|
|
798 |
axes[i].set_yticks(np.arange(-2.0, 2.1, 1)) |
|
|
799 |
|
|
|
800 |
#axes[2].legend(prop={'size': 10}) |
|
|
801 |
fig.savefig(ckpt_name+'/%s_HBP.png' % (model)) |
|
|
802 |
|
|
|
803 |
|
|
|
804 |
def makeAUROCPlot(ckpt_name='./checkpoints/grad_15/', model_list=['path', 'omic', 'pathgraphomic_fusion'], split='test', avg='micro', use_zoom=False): |
|
|
805 |
mpl.rcParams['font.family'] = "arial" |
|
|
806 |
colors = {'path':'dodgerblue', 'graph':'orange', 'omic':'green', 'pathgraphomic_fusion':'crimson'} |
|
|
807 |
names = {'path':'Histology CNN', 'graph':'Histology GCN', 'omic':'Genomic SNN', 'pathgraphomic_fusion':'Pathomic F.'} |
|
|
808 |
zoom_params = {0:([0.2, 0.4], [0.8, 1.0]), |
|
|
809 |
1:([0.25, 0.45], [0.75, 0.95]), |
|
|
810 |
2:([0.0, 0.2], [0.8, 1.0]), |
|
|
811 |
'micro':([0.15, 0.35], [0.8, 1.0])} |
|
|
812 |
mean_fpr = np.linspace(0, 1, 100) |
|
|
813 |
classes = [0, 1, 2, avg] |
|
|
814 |
### 1. Looping over classes |
|
|
815 |
for i in classes: |
|
|
816 |
print("Class: " + str(i)) |
|
|
817 |
fi = pylab.figure(figsize=(10,10), dpi=600, linewidth=0.2) |
|
|
818 |
axi = plt.subplot() |
|
|
819 |
|
|
|
820 |
### 2. Looping over models |
|
|
821 |
for m, model in enumerate(model_list): |
|
|
822 |
ignore_missing_moltype = 1 if 'omic' in model else 0 |
|
|
823 |
ignore_missing_histype = 1 if 'grad' in ckpt_name else 0 |
|
|
824 |
use_patch, roi_dir, use_vgg_features = ('_patch_', 'all_st_patches_512', 1) if (('path' in model) or ('graph' in model)) else ('_', 'all_st', 0) |
|
|
825 |
|
|
|
826 |
###. 3. Looping over all splits |
|
|
827 |
tprs, pres, aucrocs, rocaucs, = [], [], [], [] |
|
|
828 |
for k in range(1,16): |
|
|
829 |
pred = pickle.load(open(ckpt_name+'/%s/%s_%d%spred_%s.pkl' % (model, model, k, use_patch, split), 'rb')) |
|
|
830 |
grade_pred, grade = np.array(pred[3]), np.array(pred[4]) |
|
|
831 |
enc = LabelBinarizer() |
|
|
832 |
enc.fit(grade) |
|
|
833 |
grade_oh = enc.transform(grade) |
|
|
834 |
|
|
|
835 |
if i != avg: |
|
|
836 |
pres.append(average_precision_score(grade_oh[:, i], grade_pred[:, i])) # from https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html |
|
|
837 |
fpr, tpr, thresh = roc_curve(grade_oh[:,i], grade_pred[:,i], drop_intermediate=False) |
|
|
838 |
aucrocs.append(auc(fpr, tpr)) # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html |
|
|
839 |
rocaucs.append(roc_auc_score(grade_oh[:,i], grade_pred[:,i])) # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score |
|
|
840 |
tprs.append(interp(mean_fpr, fpr, tpr)) |
|
|
841 |
tprs[-1][0] = 0.0 |
|
|
842 |
else: |
|
|
843 |
# A "micro-average": quantifying score on all classes jointly |
|
|
844 |
pres.append(average_precision_score(grade_oh, grade_pred, average=avg)) |
|
|
845 |
fpr, tpr, thresh = roc_curve(grade_oh.ravel(), grade_pred.ravel()) |
|
|
846 |
aucrocs.append(auc(fpr, tpr)) |
|
|
847 |
rocaucs.append(roc_auc_score(grade_oh, grade_pred, avg)) |
|
|
848 |
tprs.append(interp(mean_fpr, fpr, tpr)) |
|
|
849 |
tprs[-1][0] = 0.0 |
|
|
850 |
|
|
|
851 |
mean_tpr = np.mean(tprs, axis=0) |
|
|
852 |
mean_tpr[-1] = 1.0 |
|
|
853 |
#mean_auc = auc(mean_fpr, mean_tpr) |
|
|
854 |
mean_auc = np.mean(aucrocs) |
|
|
855 |
std_auc = np.std(aucrocs) |
|
|
856 |
print('\t'+'%s - AUC: %0.3f ± %0.3f' % (model, mean_auc, std_auc)) |
|
|
857 |
|
|
|
858 |
if use_zoom: |
|
|
859 |
alpha, lw = (0.8, 6) if model =='pathgraphomic_fusion' else (0.5, 6) |
|
|
860 |
plt.plot(mean_fpr, mean_tpr, color=colors[model], |
|
|
861 |
label=r'%s (AUC = %0.3f $\pm$ %0.3f)' % (names[model], mean_auc, std_auc), lw=lw, alpha=alpha) |
|
|
862 |
std_tpr = np.std(tprs, axis=0) |
|
|
863 |
tprs_upper = np.minimum(mean_tpr + std_tpr, 1) |
|
|
864 |
tprs_lower = np.maximum(mean_tpr - std_tpr, 0) |
|
|
865 |
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color=colors[model], alpha=0.1) |
|
|
866 |
plt.xlim([zoom_params[i][0][0]-0.005, zoom_params[i][0][1]+0.005]) |
|
|
867 |
plt.ylim([zoom_params[i][1][0]-0.005, zoom_params[i][1][1]+0.005]) |
|
|
868 |
axi.set_xticks(np.arange(zoom_params[i][0][0], zoom_params[i][0][1]+0.001, 0.05)) |
|
|
869 |
axi.set_yticks(np.arange(zoom_params[i][1][0], zoom_params[i][1][1]+0.001, 0.05)) |
|
|
870 |
axi.tick_params(axis='both', which='major', labelsize=26) |
|
|
871 |
else: |
|
|
872 |
alpha, lw = (0.8, 4) if model =='pathgraphomic_fusion' else (0.5, 3) |
|
|
873 |
plt.plot(mean_fpr, mean_tpr, color=colors[model], |
|
|
874 |
label=r'%s (AUC = %0.3f $\pm$ %0.3f)' % (names[model], mean_auc, std_auc), lw=lw, alpha=alpha) |
|
|
875 |
std_tpr = np.std(tprs, axis=0) |
|
|
876 |
tprs_upper = np.minimum(mean_tpr + std_tpr, 1) |
|
|
877 |
tprs_lower = np.maximum(mean_tpr - std_tpr, 0) |
|
|
878 |
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color=colors[model], alpha=0.1) |
|
|
879 |
plt.xlim([-0.05, 1.05]) |
|
|
880 |
plt.ylim([-0.05, 1.05]) |
|
|
881 |
axi.set_xticks(np.arange(0, 1.001, 0.2)) |
|
|
882 |
axi.set_yticks(np.arange(0, 1.001, 0.2)) |
|
|
883 |
axi.legend(loc="lower right", prop={'size': 20}) |
|
|
884 |
axi.tick_params(axis='both', which='major', labelsize=30) |
|
|
885 |
#plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='navy', alpha=.8) |
|
|
886 |
|
|
|
887 |
figures = [manager.canvas.figure |
|
|
888 |
for manager in mpl._pylab_helpers.Gcf.get_all_fig_managers()] |
|
|
889 |
|
|
|
890 |
zoom = '_zoom' if use_zoom else '' |
|
|
891 |
for i, fig in enumerate(figures): |
|
|
892 |
fig.savefig(ckpt_name+'/AUC_%s%s.png' % (classes[i], zoom)) |