b/model/SALMON.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Zhi Huang
"""
import argparse
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
import math
from imblearn.over_sampling import RandomOverSampler
from lifelines.statistics import logrank_test
from lifelines.utils import concordance_index
import tables
import csv
import numpy as np
import json
from tqdm import tqdm
import gc
import copy

class SALMON(nn.Module):
    def __init__(self, input_dim, dropout_rate, length_of_data, label_dim):
        super(SALMON, self).__init__()
        # NOTE: dropout_rate is accepted for interface compatibility, but no
        # dropout layer is used in this architecture.
        self.length_of_data = length_of_data
        hidden1 = 8  # bottleneck size of the mRNAseq encoder
        hidden2 = 4  # bottleneck size of the miRNAseq encoder

        # The network layout is dispatched on the total input width, so each
        # supported modality combination must have a unique total dimension.
        if input_dim == length_of_data['mRNAseq']:  # mRNAseq
            self.encoder1 = nn.Sequential(nn.Linear(input_dim, hidden1), nn.Sigmoid())
            self.classifier = nn.Sequential(nn.Linear(hidden1, label_dim), nn.Sigmoid())

        if input_dim == length_of_data['miRNAseq']:  # miRNAseq
            self.encoder2 = nn.Sequential(nn.Linear(input_dim, hidden2), nn.Sigmoid())
            self.classifier = nn.Sequential(nn.Linear(hidden2, label_dim), nn.Sigmoid())

        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq']:  # mRNAseq + miRNAseq
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1), nn.Sigmoid())
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2), nn.Sigmoid())
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2, label_dim), nn.Sigmoid())

        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq'] + length_of_data['CNB'] + length_of_data['TMB']:  # mRNAseq + miRNAseq + CNB + TMB
            # CNB and TMB are low-dimensional, so they bypass the encoders and
            # are fed to the classifier directly.
            hidden_cnv, hidden_tmb = length_of_data['CNB'], length_of_data['TMB']
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1), nn.Sigmoid())
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2), nn.Sigmoid())
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2 + hidden_cnv + hidden_tmb, label_dim), nn.Sigmoid())

        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq'] + length_of_data['CNB'] + length_of_data['TMB'] + length_of_data['clinical']:  # mRNAseq + miRNAseq + CNB + TMB + clinical
            hidden_cnv, hidden_tmb, hidden_clinical = length_of_data['CNB'], length_of_data['TMB'], length_of_data['clinical']
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1), nn.Sigmoid())
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2), nn.Sigmoid())
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2 + hidden_cnv + hidden_tmb + hidden_clinical, label_dim), nn.Sigmoid())

        if input_dim == length_of_data['CNB'] + length_of_data['TMB'] + length_of_data['clinical']:  # CNB + TMB + clinical
            hidden_cnv, hidden_tmb, hidden_clinical = length_of_data['CNB'], length_of_data['TMB'], length_of_data['clinical']
            self.classifier = nn.Sequential(nn.Linear(hidden_cnv + hidden_tmb + hidden_clinical, label_dim), nn.Sigmoid())

        if input_dim == length_of_data['mRNAseq'] + length_of_data['miRNAseq'] + length_of_data['clinical']:  # mRNAseq + miRNAseq + clinical
            hidden_clinical = length_of_data['clinical']
            self.encoder1 = nn.Sequential(nn.Linear(length_of_data['mRNAseq'], hidden1), nn.Sigmoid())
            self.encoder2 = nn.Sequential(nn.Linear(length_of_data['miRNAseq'], hidden2), nn.Sigmoid())
            self.classifier = nn.Sequential(nn.Linear(hidden1 + hidden2 + hidden_clinical, label_dim), nn.Sigmoid())

    def forward(self, x):
        input_dim = x.shape[1]
        x_d = None  # placeholder for a decoder output, kept so callers always unpack three values
        n_mrna = self.length_of_data['mRNAseq']
        n_mirna = self.length_of_data['miRNAseq']

        if input_dim == n_mrna:  # mRNAseq
            code1 = self.encoder1(x)
            lbl_pred = self.classifier(code1)  # predicted hazard score
            code = code1

        if input_dim == n_mirna:  # miRNAseq
            code2 = self.encoder2(x)
            lbl_pred = self.classifier(code2)  # predicted hazard score
            code = code2

        if input_dim == n_mrna + n_mirna:  # mRNAseq + miRNAseq
            code1 = self.encoder1(x[:, 0:n_mrna])
            code2 = self.encoder2(x[:, n_mrna:])
            lbl_pred = self.classifier(torch.cat((code1, code2), 1))  # predicted hazard score
            code = torch.cat((code1, code2), 1)

        if input_dim == n_mrna + n_mirna + self.length_of_data['CNB'] + self.length_of_data['TMB']:  # mRNAseq + miRNAseq + CNB + TMB
            code1 = self.encoder1(x[:, 0:n_mrna])
            code2 = self.encoder2(x[:, n_mrna:n_mrna + n_mirna])
            # raw CNB/TMB columns are appended to the learned codes
            lbl_pred = self.classifier(torch.cat((code1, code2, x[:, n_mrna + n_mirna:]), 1))
            code = torch.cat((code1, code2), 1)

        if input_dim == n_mrna + n_mirna + self.length_of_data['CNB'] + self.length_of_data['TMB'] + self.length_of_data['clinical']:  # mRNAseq + miRNAseq + CNB + TMB + clinical
            code1 = self.encoder1(x[:, 0:n_mrna])
            code2 = self.encoder2(x[:, n_mrna:n_mrna + n_mirna])
            lbl_pred = self.classifier(torch.cat((code1, code2, x[:, n_mrna + n_mirna:]), 1))
            code = torch.cat((code1, code2), 1)

        if input_dim == self.length_of_data['CNB'] + self.length_of_data['TMB'] + self.length_of_data['clinical']:  # CNB + TMB + clinical
            lbl_pred = self.classifier(x)
            code = torch.FloatTensor([0])  # no learned code in this configuration

        if input_dim == n_mrna + n_mirna + self.length_of_data['clinical']:  # mRNAseq + miRNAseq + clinical
            code1 = self.encoder1(x[:, 0:n_mrna])
            code2 = self.encoder2(x[:, n_mrna:n_mrna + n_mirna])
            lbl_pred = self.classifier(torch.cat((code1, code2, x[:, n_mrna + n_mirna:]), 1))
            code = torch.cat((code1, code2), 1)

        return x_d, code, lbl_pred
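

# A minimal usage sketch of SALMON's width-based dispatch (the modality widths
# below are illustrative assumptions, not values from the original
# experiments). Concatenating the mRNAseq and miRNAseq columns, in that order,
# selects the two-encoder branch of forward().
def _demo_salmon_forward():
    length_of_data = {'mRNAseq': 20, 'miRNAseq': 10, 'CNB': 1, 'TMB': 1, 'clinical': 3}
    net = SALMON(input_dim=30, dropout_rate=0.0, length_of_data=length_of_data, label_dim=1)
    x = torch.randn(5, 30)  # batch of 5 patients: 20 mRNAseq + 10 miRNAseq features
    x_d, code, hazard = net(x)
    assert code.shape == (5, 12)   # hidden1 (8) + hidden2 (4)
    assert hazard.shape == (5, 1)  # one hazard score per patient
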
def accuracy(output, labels):
    # classification accuracy: fraction of argmax predictions matching labels
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

def accuracy_cox(hazards, labels):
    # Accuracy of estimated survival events against true survival events,
    # with predicted events defined by a median split of the hazard scores.
    hazardsdata = hazards.cpu().numpy().reshape(-1)
    median = np.median(hazardsdata)
    hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int)
    hazards_dichotomize[hazardsdata > median] = 1
    labels = labels.data.cpu().numpy()
    correct = np.sum(hazards_dichotomize == labels)
    return correct / len(labels)

def cox_log_rank(hazards, labels, survtime_all):
    # Log-rank test between the low-risk and high-risk groups obtained by a
    # median split of the hazard scores.
    hazardsdata = hazards.cpu().numpy().reshape(-1)
    median = np.median(hazardsdata)
    hazards_dichotomize = np.zeros([len(hazardsdata)], dtype=int)
    hazards_dichotomize[hazardsdata > median] = 1
    survtime_all = survtime_all.data.cpu().numpy().reshape(-1)
    idx = hazards_dichotomize == 0
    labels = labels.data.cpu().numpy()
    T1 = survtime_all[idx]   # low-risk group: survival times
    T2 = survtime_all[~idx]  # high-risk group: survival times
    E1 = labels[idx]         # low-risk group: event indicators
    E2 = labels[~idx]        # high-risk group: event indicators
    results = logrank_test(T1, T2, event_observed_A=E1, event_observed_B=E2)
    pvalue_pred = results.p_value
    return pvalue_pred

def CIndex(hazards, labels, survtime_all):
    # Concordance index computed directly over comparable pairs: for a pair
    # (i, j) with an observed event at i and survtime[j] > survtime[i], the
    # pair is concordant when hazards[j] < hazards[i]; ties in predicted
    # hazard count as 0.5. Assumes at least one comparable pair (total > 0).
    labels = labels.data.cpu().numpy()
    concord = 0.
    total = 0.
    N_test = labels.shape[0]
    labels = np.asarray(labels, dtype=bool)
    for i in range(N_test):
        if labels[i] == 1:
            for j in range(N_test):
                if survtime_all[j] > survtime_all[i]:
                    total = total + 1
                    if hazards[j] < hazards[i]:
                        concord = concord + 1
                    elif hazards[j] == hazards[i]:  # tie in predicted hazard
                        concord = concord + 0.5

    return concord / total

def CIndex_lifeline(hazards, labels, survtime_all):
    # Concordance index via lifelines; hazards are negated because
    # concordance_index expects higher scores to mean longer survival.
    labels = labels.data.cpu().numpy()
    hazards = hazards.cpu().numpy().reshape(-1)
    return concordance_index(survtime_all, -hazards, labels)
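
# A small, self-contained sketch of the survival metrics above on toy values
# (the numbers are made up for illustration). Hazards are dichotomized at the
# median for accuracy_cox and cox_log_rank, while the concordance index checks
# the pairwise ordering of hazards against survival times.
def _demo_survival_metrics():
    hazards = torch.FloatTensor([[0.9], [0.8], [0.2], [0.1]])  # high risk first
    events = torch.LongTensor([1, 1, 1, 0])                    # 1 = event observed
    times = torch.FloatTensor([2.0, 3.0, 10.0, 12.0])          # low risk lives longer
    print('accuracy:', accuracy_cox(hazards, events))
    print('log-rank p-value:', cox_log_rank(hazards, events, times))
    print('c-index:', CIndex_lifeline(hazards, events, times))
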
def frobenius_norm_loss(a, b):
    # Frobenius norm of the difference, i.e. sqrt(sum |a - b|^2)
    loss = torch.sqrt(torch.sum(torch.abs(a - b) ** 2))
    return loss
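
# Sanity check: frobenius_norm_loss(a, b) equals the Frobenius norm
# ||a - b||_F, i.e. torch.norm(a - b) with the default 'fro' order.
def _demo_frobenius_norm():
    a, b = torch.randn(3, 4), torch.randn(3, 4)
    assert torch.allclose(frobenius_norm_loss(a, b), torch.norm(a - b))
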
def test(model, datasets, whichset, length_of_data, batch_size, cuda, verbose):
    # Evaluate the model on datasets[whichset]; length_of_data is unused here
    # but kept for signature parity with train().
    x = datasets[whichset]['x']
    e = datasets[whichset]['e']
    t = datasets[whichset]['t']
    X = torch.FloatTensor(x)
    OS_event = torch.LongTensor(e)
    OS = torch.FloatTensor(t)
    # shuffle=False keeps the feature, event, and time loaders aligned sample-by-sample
    dataloader = DataLoader(X, batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=False)
    lblloader = DataLoader(OS_event, batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=False)
    OSloader = DataLoader(OS, batch_size=batch_size, num_workers=1, pin_memory=True, shuffle=False)
    lbl_pred_all = None
    lbl_all = None
    survtime_all = None
    code_final = None
    loss_nn_sum = 0
    model.eval()
    batch_idx = 0
    for data, lbl, survtime in zip(dataloader, lblloader, OSloader):
        graph = Variable(data)
        lbl = Variable(lbl)
        if cuda:
            model = model.cuda()
            graph = graph.cuda()
            lbl = lbl.cuda()
        # ===================forward=====================
        output, code, lbl_pred = model(graph)
        if batch_idx == 0:
            lbl_pred_all = lbl_pred
            lbl_all = lbl
            survtime_all = survtime
            code_final = code
        else:
            lbl_pred_all = torch.cat([lbl_pred_all, lbl_pred])
            lbl_all = torch.cat([lbl_all, lbl])
            survtime_all = torch.cat([survtime_all, survtime])
            code_final = torch.cat([code_final, code])

        # Risk-set indicator: R[i, j] = 1 iff subject j is still at risk at
        # subject i's event time (survtime[j] >= survtime[i]).
        current_batch_len = len(survtime)
        R_matrix_test = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_matrix_test[i, j] = survtime[j] >= survtime[i]

        test_R = Variable(torch.FloatTensor(R_matrix_test))
        if cuda:
            test_R = test_R.cuda()
        test_ystatus = lbl
        theta = lbl_pred.reshape(-1)
        exp_theta = torch.exp(theta)
        # negative Cox partial log-likelihood over this batch
        loss_nn = -torch.mean((theta - torch.log(torch.sum(exp_theta * test_R, dim=1))) * test_ystatus.float())
        loss_nn_sum = loss_nn_sum + loss_nn.data.item()
        batch_idx += 1
    code_final_4_original_data = code_final.data.cpu().numpy()
    acc_test = accuracy_cox(lbl_pred_all.data, lbl_all)
    pvalue_pred = cox_log_rank(lbl_pred_all.data, lbl_all, survtime_all)
    c_index = CIndex_lifeline(lbl_pred_all.data, lbl_all, survtime_all)
    if verbose > 0:
        print('\n[{:s}]\t\tloss (nn):{:.4f}'.format(whichset, loss_nn_sum),
              'c_index: {:.4f}, p-value: {:.3e}'.format(c_index, pvalue_pred))
    return (code_final_4_original_data, loss_nn_sum, acc_test,
            pvalue_pred, c_index, lbl_pred_all.data.cpu().numpy().reshape(-1), OS_event, survtime_all)
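
# The per-batch loss in test() above (and again in train() below) builds the
# risk-set matrix with a Python double loop. The same negative Cox partial
# log-likelihood can be computed fully vectorized; this helper is an
# illustrative sketch (the name is ours, not part of the original code).
def _neg_partial_log_likelihood(theta, survtime, ystatus):
    # theta:    (n,) risk scores, i.e. lbl_pred.reshape(-1)
    # survtime: (n,) survival times
    # ystatus:  (n,) event indicators (1 = event observed)
    # R[i, j] = 1 iff survtime[j] >= survtime[i] (subject j at risk at i's event time)
    R = (survtime.reshape(1, -1) >= survtime.reshape(-1, 1)).float()
    exp_theta = torch.exp(theta)
    return -torch.mean((theta - torch.log(torch.sum(exp_theta * R, dim=1))) * ystatus.float())
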
def init_weights(m):
    # Initialize every nn.Linear weight from a normal with mean 0, std 0.5;
    # isinstance is the idiomatic type check here.
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.5)
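
# Usage sketch: init_weights is meant to be applied recursively through
# nn.Module.apply; note that nothing in this file calls it. The dimensions
# below are illustrative assumptions.
def _demo_init_weights():
    lod = {'mRNAseq': 8, 'miRNAseq': 4, 'CNB': 1, 'TMB': 1, 'clinical': 2}
    net = SALMON(input_dim=8, dropout_rate=0.0, length_of_data=lod, label_dim=1)
    net.apply(init_weights)  # re-draws every nn.Linear weight (mean 0, std 0.5)
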
def train(datasets, num_epochs, batch_size, learning_rate, dropout_rate,
          lambda_1, length_of_data, cuda, measure, verbose):

    x = datasets['train']['x']
    e = datasets['train']['e']
    t = datasets['train']['t']
    nodes_in = x.shape[1]

    X = torch.FloatTensor(x)
    OS_event = torch.LongTensor(e)
    OS = torch.FloatTensor(t)

    # shuffle=False keeps the feature, event, and time loaders aligned sample-by-sample
    dataloader = DataLoader(X, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=False)
    lblloader = DataLoader(OS_event, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=False)
    OSloader = DataLoader(OS, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=False)

    # fixed seeds for reproducibility
    cudnn.deterministic = True
    torch.cuda.manual_seed_all(666)
    torch.manual_seed(666)
    random.seed(666)

    model = SALMON(nodes_in, dropout_rate, length_of_data, label_dim=1)

    if cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)

    c_index_list = {'train': [], 'test': []}
    loss_nn_all = []
    pvalue_all = []
    c_index_all = []
    acc_train_all = []
    c_index_best = 0
    code_output = None

    for epoch in tqdm(range(num_epochs)):
        model.train()
        lbl_pred_all = None
        lbl_all = None
        survtime_all = None
        code_final = None
        loss_nn_sum = 0
        batch_idx = 0
        gc.collect()
        for data, lbl, survtime in zip(dataloader, lblloader, OSloader):
            optimizer.zero_grad()  # zero the gradient buffer
            graph = data
            if cuda:
                model = model.cuda()
                graph = graph.cuda()
                lbl = lbl.cuda()
            # ===================forward=====================
            output, code, lbl_pred = model(graph)

            if batch_idx == 0:
                lbl_pred_all = lbl_pred
                survtime_all = survtime
                lbl_all = lbl
                code_final = code
            else:
                lbl_pred_all = torch.cat([lbl_pred_all, lbl_pred])
                lbl_all = torch.cat([lbl_all, lbl])
                survtime_all = torch.cat([survtime_all, survtime])
                code_final = torch.cat([code_final, code])

            # This calculation is credited to Travers Ching, https://github.com/traversc/cox-nnet
            # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
            # R[i, j] = 1 iff subject j is still at risk at subject i's event time.
            current_batch_len = len(survtime)
            R_matrix_train = np.zeros([current_batch_len, current_batch_len], dtype=int)
            for i in range(current_batch_len):
                for j in range(current_batch_len):
                    R_matrix_train[i, j] = survtime[j] >= survtime[i]

            train_R = torch.FloatTensor(R_matrix_train)
            if cuda:
                train_R = train_R.cuda()
            train_ystatus = lbl

            theta = lbl_pred.reshape(-1)
            exp_theta = torch.exp(theta)

            # negative Cox partial log-likelihood over this batch
            loss_nn = -torch.mean((theta - torch.log(torch.sum(exp_theta * train_R, dim=1))) * train_ystatus.float())

            # L1 regularization over all parameters; torch.abs(W).sum() is equivalent to W.norm(1)
            l1_reg = None
            for W in model.parameters():
                if l1_reg is None:
                    l1_reg = torch.abs(W).sum()
                else:
                    l1_reg = l1_reg + torch.abs(W).sum()

            loss = loss_nn + lambda_1 * l1_reg
            if verbose > 0:
                print("\nloss_nn: %.4f, L1: %.4f" % (loss_nn, lambda_1 * l1_reg))
            loss_nn_sum = loss_nn_sum + loss_nn.data.item()
            # ===================backward====================
            loss.backward()
            optimizer.step()

            batch_idx += 1
        torch.cuda.empty_cache()
        code_final_4_original_data = code_final.data.cpu().numpy()

        if measure or epoch == (num_epochs - 1):
            acc_train = accuracy_cox(lbl_pred_all.data, lbl_all)
            pvalue_pred = cox_log_rank(lbl_pred_all.data, lbl_all, survtime_all)
            c_index = CIndex_lifeline(lbl_pred_all.data, lbl_all, survtime_all)

            c_index_list['train'].append(c_index)
            # keep the codes from the epoch with the best training c-index
            if c_index > c_index_best:
                c_index_best = c_index
                code_output = code_final_4_original_data
            if verbose > 0:
                print('\n[Training]\t loss (nn):{:.4f}'.format(loss_nn_sum),
                      'c_index: {:.4f}, p-value: {:.3e}'.format(c_index, pvalue_pred))
            pvalue_all.append(pvalue_pred)
            c_index_all.append(c_index)
            loss_nn_all.append(loss_nn_sum)
            acc_train_all.append(acc_train)
            whichset = 'test'
            code_validation, loss_nn_sum, acc_test, pvalue_pred, c_index_pred, lbl_pred_all, OS_event, OS = \
                test(model, datasets, whichset, length_of_data, batch_size, cuda, verbose)

            c_index_list['test'].append(c_index_pred)
    return (model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output)
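
# A minimal end-to-end sketch of the expected `datasets` layout and a training
# call on synthetic data (all sizes and hyperparameters below are illustrative
# assumptions, not values from the original experiments).
def _demo_train():
    rng = np.random.RandomState(0)
    length_of_data = {'mRNAseq': 20, 'miRNAseq': 10, 'CNB': 1, 'TMB': 1, 'clinical': 3}

    def fake_split(n):
        return {'x': rng.randn(n, 30).astype(np.float32),          # mRNAseq + miRNAseq columns
                'e': rng.binomial(1, 0.7, n).astype(np.int64),     # event indicator
                't': rng.exponential(365.0, n).astype(np.float32)}  # survival time

    data = {'train': fake_split(64), 'test': fake_split(32)}
    model, loss_nn_all, pvalue_all, c_index_all, c_index_list, acc_train_all, code_output = \
        train(data, num_epochs=2, batch_size=32, learning_rate=1e-3,
              dropout_rate=0.0, lambda_1=1e-5, length_of_data=length_of_data,
              cuda=False, measure=True, verbose=0)
    print('final test c-index:', c_index_list['test'][-1])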