Diff of /train_test.py [000000] .. [2095ed]

import random
from tqdm import tqdm
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import RandomSampler

from data_loaders import PathgraphomicDatasetLoader, PathgraphomicFastDatasetLoader
from networks import define_net, define_reg, define_optimizer, define_scheduler
from utils import unfreeze_unimodal, CoxLoss, CIndex_lifeline, cox_log_rank, accuracy_cox, mixed_collate, count_parameters

#from GPUtil import showUtilization as gpu_usage
import pdb
import pickle
import os

def train(opt, data, device, k):
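    # Fix the Python and torch (CPU + CUDA) RNG seeds and force deterministic
    # cuDNN kernels so each cross-validation fold run is reproducible.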
    cudnn.deterministic = True
    torch.cuda.manual_seed_all(2019)
    torch.manual_seed(2019)
    random.seed(2019)

    model = define_net(opt, k)
    optimizer = define_optimizer(opt, model)
    scheduler = define_scheduler(opt, optimizer)
    print(model)
    print("Number of Trainable Parameters: %d" % count_parameters(model))
    print("Activation Type:", opt.act_type)
    print("Optimizer Type:", opt.optimizer_type)
    print("Regularization Type:", opt.reg_type)

    use_patch, roi_dir = ('_patch_', 'all_st_patches_512') if opt.use_vgg_features else ('_', 'all_st')
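
    # With opt.use_vgg_features set, the "Fast" loader presumably reads
    # pre-extracted VGG embeddings over 512px ROI patches; otherwise raw
    # path/graph/omic inputs are loaded per sample.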
    custom_data_loader = PathgraphomicFastDatasetLoader(opt, data, split='train', mode=opt.mode) if opt.use_vgg_features else PathgraphomicDatasetLoader(opt, data, split='train', mode=opt.mode)
    train_loader = torch.utils.data.DataLoader(dataset=custom_data_loader, batch_size=opt.batch_size, shuffle=True, collate_fn=mixed_collate)
    metric_logger = {'train':{'loss':[], 'pvalue':[], 'cindex':[], 'surv_acc':[], 'grad_acc':[]},
                     'test':{'loss':[], 'pvalue':[], 'cindex':[], 'surv_acc':[], 'grad_acc':[]}}

    for epoch in tqdm(range(opt.epoch_count, opt.niter+opt.niter_decay+1)):

        if opt.finetune == 1:
            unfreeze_unimodal(opt, model, epoch)

        model.train()
        risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([])    # Used for calculating the C-Index
        loss_epoch, grad_acc_epoch = 0, 0

        for batch_idx, (x_path, x_grph, x_omic, censor, survtime, grade) in enumerate(train_loader):

            censor = censor.to(device) if "surv" in opt.task else censor
            grade = grade.to(device) if "grad" in opt.task else grade
            _, pred = model(x_path=x_path.to(device), x_grph=x_grph.to(device), x_omic=x_omic.to(device))
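
            # Weighted multi-task objective: Cox loss for survival, NLL for
            # grade, plus a weight penalty chosen by opt.reg_type. CoxLoss
            # (see utils.py) is assumed to implement the negative Cox partial
            # log-likelihood over the batch,
            #   -mean_i[ censor_i * (h_i - log sum_{j: t_j >= t_i} exp(h_j)) ],
            # where h is the predicted risk score.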
            loss_cox = CoxLoss(survtime, censor, pred, device) if opt.task == "surv" else 0
            loss_reg = define_reg(opt, model)
            loss_nll = F.nll_loss(pred, grade) if opt.task == "grad" else 0
            loss = opt.lambda_cox*loss_cox + opt.lambda_nll*loss_nll + opt.lambda_reg*loss_reg
            loss_epoch += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
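
            # Accumulate detached predictions on the CPU so epoch-level survival
            # statistics (C-index, log-rank p-value) can be computed over the
            # whole training set.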
            if opt.task == "surv":
                risk_pred_all = np.concatenate((risk_pred_all, pred.detach().cpu().numpy().reshape(-1)))
                censor_all = np.concatenate((censor_all, censor.detach().cpu().numpy().reshape(-1)))
                survtime_all = np.concatenate((survtime_all, survtime.detach().cpu().numpy().reshape(-1)))
            elif opt.task == "grad":
                pred = pred.argmax(dim=1, keepdim=True)
                grad_acc_epoch += pred.eq(grade.view_as(pred)).sum().item()

            if opt.verbose > 0 and opt.print_every > 0 and (batch_idx % opt.print_every == 0 or batch_idx+1 == len(train_loader)):
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                    epoch+1, opt.niter+opt.niter_decay, batch_idx+1, len(train_loader), loss.item()))

        scheduler.step()
        # lr = optimizer.param_groups[0]['lr']
        # print('learning rate = %.7f' % lr)
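
        # Metrics are logged every epoch when opt.measure is set, and always on
        # the final epoch.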
        if opt.measure or epoch == (opt.niter+opt.niter_decay - 1):
            loss_epoch /= len(train_loader)

            cindex_epoch = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
            pvalue_epoch = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
            surv_acc_epoch = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None
            grad_acc_epoch = grad_acc_epoch / len(train_loader.dataset) if opt.task == 'grad' else None
            loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test = test(opt, model, data, 'test', device)

            metric_logger['train']['loss'].append(loss_epoch)
            metric_logger['train']['cindex'].append(cindex_epoch)
            metric_logger['train']['pvalue'].append(pvalue_epoch)
            metric_logger['train']['surv_acc'].append(surv_acc_epoch)
            metric_logger['train']['grad_acc'].append(grad_acc_epoch)

            metric_logger['test']['loss'].append(loss_test)
            metric_logger['test']['cindex'].append(cindex_test)
            metric_logger['test']['pvalue'].append(pvalue_test)
            metric_logger['test']['surv_acc'].append(surv_acc_test)
            metric_logger['test']['grad_acc'].append(grad_acc_test)

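            # Snapshot the raw test-set predictions for fold k at this epoch;
            # use_patch distinguishes patch-level runs in the filename.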
            pickle.dump(pred_test, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d%s%d_pred_test.pkl' % (opt.model_name, k, use_patch, epoch)), 'wb'))

            if opt.verbose > 0:
                if opt.task == 'surv':
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format('Train', loss_epoch, 'C-Index', cindex_epoch))
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format('Test', loss_test, 'C-Index', cindex_test))
                elif opt.task == 'grad':
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format('Train', loss_epoch, 'Accuracy', grad_acc_epoch))
                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format('Test', loss_test, 'Accuracy', grad_acc_test))
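
            # Note: opt.patience acts here as an absolute loss threshold for the
            # grade task (stop once the epoch loss drops below it), not as a
            # "rounds without improvement" count.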
            if opt.task == 'grad' and loss_epoch < opt.patience:
                print("Early stopping at Epoch %d" % epoch)
                break

    return model, optimizer, metric_logger


def test(opt, model, data, split, device):
    model.eval()
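    # Inference pass: eval mode disables dropout and fixes batch-norm statistics.
    # No backward pass is taken, so the loop below could also run under
    # torch.no_grad() to reduce memory use.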

    custom_data_loader = PathgraphomicFastDatasetLoader(opt, data, split, mode=opt.mode) if opt.use_vgg_features else PathgraphomicDatasetLoader(opt, data, split=split, mode=opt.mode)
    test_loader = torch.utils.data.DataLoader(dataset=custom_data_loader, batch_size=opt.batch_size, shuffle=False, collate_fn=mixed_collate)
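    # shuffle=False keeps batch order fixed so the concatenated predictions stay
    # aligned with the ground-truth labels accumulated in gt_all.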

    risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([])
    probs_all, gt_all = None, np.array([])
    loss_test, grad_acc_test = 0, 0

    for batch_idx, (x_path, x_grph, x_omic, censor, survtime, grade) in enumerate(test_loader):

        censor = censor.to(device) if "surv" in opt.task else censor
        grade = grade.to(device) if "grad" in opt.task else grade
        _, pred = model(x_path=x_path.to(device), x_grph=x_grph.to(device), x_omic=x_omic.to(device))

        loss_cox = CoxLoss(survtime, censor, pred, device) if opt.task == "surv" else 0
        loss_reg = define_reg(opt, model)
        loss_nll = F.nll_loss(pred, grade) if opt.task == "grad" else 0
        loss = opt.lambda_cox*loss_cox + opt.lambda_nll*loss_nll + opt.lambda_reg*loss_reg
        loss_test += loss.item()

        gt_all = np.concatenate((gt_all, grade.detach().cpu().numpy().reshape(-1)))

        if opt.task == "surv":
            risk_pred_all = np.concatenate((risk_pred_all, pred.detach().cpu().numpy().reshape(-1)))
            censor_all = np.concatenate((censor_all, censor.detach().cpu().numpy().reshape(-1)))
            survtime_all = np.concatenate((survtime_all, survtime.detach().cpu().numpy().reshape(-1)))
        elif opt.task == "grad":
            grade_pred = pred.argmax(dim=1, keepdim=True)
            grad_acc_test += grade_pred.eq(grade.view_as(grade_pred)).sum().item()
            probs_np = pred.detach().cpu().numpy()
            probs_all = probs_np if probs_all is None else np.concatenate((probs_all, probs_np), axis=0)
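            # Note: F.nll_loss above expects log-probabilities, so pred (and
            # hence probs_all) is assumed to hold log-softmax outputs rather
            # than raw class probabilities.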

    ###################################################
    # ==== Measuring Test Loss, C-Index, P-Value ==== #
    ###################################################
    loss_test /= len(test_loader)
    cindex_test = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
    pvalue_test = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
    surv_acc_test = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None
    grad_acc_test = grad_acc_test / len(test_loader.dataset) if opt.task == 'grad' else None
    pred_test = [risk_pred_all, survtime_all, censor_all, probs_all, gt_all]

    return loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test