train_test.py

import os
import pickle
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from tqdm import tqdm

from data_loaders import PathgraphomicDatasetLoader, PathgraphomicFastDatasetLoader
from networks import define_net, define_reg, define_optimizer, define_scheduler
from utils import (unfreeze_unimodal, CoxLoss, CIndex_lifeline, cox_log_rank,
                   accuracy_cox, mixed_collate, count_parameters)


def train(opt, data, device, k):
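    """Train the multimodal network for one cross-validation fold.

    `opt` carries the experiment configuration, `data` the pre-split data for
    fold `k`. Returns the trained model, its optimizer, and a logger of
    per-epoch train/test metrics.
    """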
    cudnn.deterministic = True
    torch.cuda.manual_seed_all(2019)
    torch.manual_seed(2019)
    random.seed(2019)

    model = define_net(opt, k)
    optimizer = define_optimizer(opt, model)
    scheduler = define_scheduler(opt, optimizer)
    print(model)
    print("Number of Trainable Parameters: %d" % count_parameters(model))
    print("Activation Type:", opt.act_type)
    print("Optimizer Type:", opt.optimizer_type)
    print("Regularization Type:", opt.reg_type)

    use_patch, roi_dir = ('_patch_', 'all_st_patches_512') if opt.use_vgg_features else ('_', 'all_st')

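    # The "Fast" dataset loader presumably reads pre-extracted VGG features,
    # while the standard loader works from the raw inputs.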
    custom_data_loader = PathgraphomicFastDatasetLoader(opt, data, split='train', mode=opt.mode) if opt.use_vgg_features else PathgraphomicDatasetLoader(opt, data, split='train', mode=opt.mode)
    train_loader = torch.utils.data.DataLoader(dataset=custom_data_loader, batch_size=opt.batch_size, shuffle=True, collate_fn=mixed_collate)
    metric_logger = {'train': {'loss': [], 'pvalue': [], 'cindex': [], 'surv_acc': [], 'grad_acc': []},
                     'test': {'loss': [], 'pvalue': [], 'cindex': [], 'surv_acc': [], 'grad_acc': []}}

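    # Train for opt.niter epochs at the base learning rate, then opt.niter_decay
    # further epochs over which the scheduler is presumably decaying it.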
    for epoch in tqdm(range(opt.epoch_count, opt.niter + opt.niter_decay + 1)):

        if opt.finetune == 1:
            unfreeze_unimodal(opt, model, epoch)

        model.train()
        risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([])  # Used for calculating the C-Index
        loss_epoch, grad_acc_epoch = 0, 0

        for batch_idx, (x_path, x_grph, x_omic, censor, survtime, grade) in enumerate(train_loader):

            censor = censor.to(device) if "surv" in opt.task else censor
            grade = grade.to(device) if "grad" in opt.task else grade
            _, pred = model(x_path=x_path.to(device), x_grph=x_grph.to(device), x_omic=x_omic.to(device))
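
            # The objective is a weighted sum of task losses: Cox partial
            # likelihood for survival, NLL for grade classification (pred is
            # assumed to hold log-probabilities, as F.nll_loss requires), plus
            # a weight-regularization penalty from define_reg.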
            loss_cox = CoxLoss(survtime, censor, pred, device) if opt.task == "surv" else 0
            loss_reg = define_reg(opt, model)
            loss_nll = F.nll_loss(pred, grade) if opt.task == "grad" else 0
            loss = opt.lambda_cox * loss_cox + opt.lambda_nll * loss_nll + opt.lambda_reg * loss_reg
            loss_epoch += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

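            # Accumulate per-sample outputs so epoch-level survival metrics
            # (C-index, log-rank p-value, concordance accuracy) can be
            # computed over the full training split.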
            if opt.task == "surv":
                risk_pred_all = np.concatenate((risk_pred_all, pred.detach().cpu().numpy().reshape(-1)))
                censor_all = np.concatenate((censor_all, censor.detach().cpu().numpy().reshape(-1)))
                survtime_all = np.concatenate((survtime_all, survtime.detach().cpu().numpy().reshape(-1)))
            elif opt.task == "grad":
                pred = pred.argmax(dim=1, keepdim=True)
                grad_acc_epoch += pred.eq(grade.view_as(pred)).sum().item()

            if opt.verbose > 0 and opt.print_every > 0 and (batch_idx % opt.print_every == 0 or batch_idx + 1 == len(train_loader)):
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                    epoch + 1, opt.niter + opt.niter_decay, batch_idx + 1, len(train_loader), loss.item()))

        scheduler.step()

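        # Evaluate either every epoch (opt.measure) or only at the final one.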
        if opt.measure or epoch == (opt.niter + opt.niter_decay - 1):
            loss_epoch /= len(train_loader)

            cindex_epoch = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
            pvalue_epoch = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
            surv_acc_epoch = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None
            grad_acc_epoch = grad_acc_epoch / len(train_loader.dataset) if opt.task == 'grad' else None
            loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test = test(opt, model, data, 'test', device)

            metric_logger['train']['loss'].append(loss_epoch)
            metric_logger['train']['cindex'].append(cindex_epoch)
            metric_logger['train']['pvalue'].append(pvalue_epoch)
            metric_logger['train']['surv_acc'].append(surv_acc_epoch)
            metric_logger['train']['grad_acc'].append(grad_acc_epoch)

            metric_logger['test']['loss'].append(loss_test)
            metric_logger['test']['cindex'].append(cindex_test)
            metric_logger['test']['pvalue'].append(pvalue_test)
            metric_logger['test']['surv_acc'].append(surv_acc_test)
            metric_logger['test']['grad_acc'].append(grad_acc_test)

            pred_path = os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name,
                                     '%s_%d%s%d_pred_test.pkl' % (opt.model_name, k, use_patch, epoch))
            with open(pred_path, 'wb') as f:
                pickle.dump(pred_test, f)

            if opt.verbose > 0:
                if opt.task == 'surv':
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format('Train', loss_epoch, 'C-Index', cindex_epoch))
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format('Test', loss_test, 'C-Index', cindex_test))
                elif opt.task == 'grad':
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format('Train', loss_epoch, 'Accuracy', grad_acc_epoch))
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format('Test', loss_test, 'Accuracy', grad_acc_test))

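            # Note: opt.patience acts as a loss threshold here, not the usual
            # "epochs without improvement" patience counter.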
            if opt.task == 'grad' and loss_epoch < opt.patience:
                print("Early stopping at Epoch %d" % epoch)
                break

    return model, optimizer, metric_logger


@torch.no_grad()
def test(opt, model, data, split, device):
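    """Evaluate `model` on `split`, returning loss, survival/grade metrics, and raw predictions."""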
    model.eval()

    custom_data_loader = PathgraphomicFastDatasetLoader(opt, data, split, mode=opt.mode) if opt.use_vgg_features else PathgraphomicDatasetLoader(opt, data, split=split, mode=opt.mode)
    test_loader = torch.utils.data.DataLoader(dataset=custom_data_loader, batch_size=opt.batch_size, shuffle=False, collate_fn=mixed_collate)

    risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([])
    probs_all, gt_all = None, np.array([])
    loss_test, grad_acc_test = 0, 0

    for batch_idx, (x_path, x_grph, x_omic, censor, survtime, grade) in enumerate(test_loader):

        censor = censor.to(device) if "surv" in opt.task else censor
        grade = grade.to(device) if "grad" in opt.task else grade
        _, pred = model(x_path=x_path.to(device), x_grph=x_grph.to(device), x_omic=x_omic.to(device))

        loss_cox = CoxLoss(survtime, censor, pred, device) if opt.task == "surv" else 0
        loss_reg = define_reg(opt, model)
        loss_nll = F.nll_loss(pred, grade) if opt.task == "grad" else 0
        loss = opt.lambda_cox * loss_cox + opt.lambda_nll * loss_nll + opt.lambda_reg * loss_reg
        loss_test += loss.item()

        gt_all = np.concatenate((gt_all, grade.detach().cpu().numpy().reshape(-1)))  # ground-truth grades

        if opt.task == "surv":
            risk_pred_all = np.concatenate((risk_pred_all, pred.detach().cpu().numpy().reshape(-1)))
            censor_all = np.concatenate((censor_all, censor.detach().cpu().numpy().reshape(-1)))
            survtime_all = np.concatenate((survtime_all, survtime.detach().cpu().numpy().reshape(-1)))
        elif opt.task == "grad":
            grade_pred = pred.argmax(dim=1, keepdim=True)
            grad_acc_test += grade_pred.eq(grade.view_as(grade_pred)).sum().item()
            probs_np = pred.detach().cpu().numpy()
            probs_all = probs_np if probs_all is None else np.concatenate((probs_all, probs_np), axis=0)

    ###################################################
    # ==== Measuring Test Loss, C-Index, P-Value ==== #
    ###################################################
    loss_test /= len(test_loader)
    cindex_test = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
    pvalue_test = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
    surv_acc_test = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None
    grad_acc_test = grad_acc_test / len(test_loader.dataset) if opt.task == 'grad' else None
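    # Bundle raw outputs for downstream analysis and serialization in train().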
    pred_test = [risk_pred_all, survtime_all, censor_all, probs_all, gt_all]

    return loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test
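

# A minimal driver sketch, not part of the original file: everything below is
# an assumption about how train() and test() are meant to be composed across
# cross-validation folds. The `parse_args` helper, the fold dictionary layout,
# and the 'splits.pkl' path are hypothetical stand-ins for the repo's actual
# entry-point script.
if __name__ == '__main__':
    from options import parse_args  # hypothetical options module

    opt = parse_args()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # data_cv is assumed to map fold index k -> that fold's train/test split.
    data_cv = pickle.load(open('splits.pkl', 'rb'))  # hypothetical path

    for k, data in data_cv.items():
        model, optimizer, metric_logger = train(opt, data, device, k)
        loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test = test(opt, model, data, 'test', device)
        print('Fold %d | Test Loss: %.4f' % (k, loss_test))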