--- /dev/null
+++ b/train_test.py
@@ -0,0 +1,164 @@
+import os
+import pickle
+import random
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from data_loaders import PathgraphomicDatasetLoader, PathgraphomicFastDatasetLoader
+from networks import define_net, define_reg, define_optimizer, define_scheduler
+from utils import (unfreeze_unimodal, CoxLoss, CIndex_lifeline, cox_log_rank,
+                   accuracy_cox, mixed_collate, count_parameters)
+
+
+def train(opt, data, device, k):
+    # Fix all seeds so that each cross-validation fold is reproducible.
+    cudnn.deterministic = True
+    torch.cuda.manual_seed_all(2019)
+    torch.manual_seed(2019)
+    random.seed(2019)
+
+    model = define_net(opt, k)
+    optimizer = define_optimizer(opt, model)
+    scheduler = define_scheduler(opt, optimizer)
+    print(model)
+    print("Number of Trainable Parameters: %d" % count_parameters(model))
+    print("Activation Type:", opt.act_type)
+    print("Optimizer Type:", opt.optimizer_type)
+    print("Regularization Type:", opt.reg_type)
+
+    use_patch, roi_dir = ('_patch_', 'all_st_patches_512') if opt.use_vgg_features else ('_', 'all_st')
+
+    custom_data_loader = (PathgraphomicFastDatasetLoader(opt, data, split='train', mode=opt.mode)
+                          if opt.use_vgg_features
+                          else PathgraphomicDatasetLoader(opt, data, split='train', mode=opt.mode))
+    train_loader = DataLoader(dataset=custom_data_loader, batch_size=opt.batch_size,
+                              shuffle=True, collate_fn=mixed_collate)
+    metric_logger = {'train': {'loss': [], 'pvalue': [], 'cindex': [], 'surv_acc': [], 'grad_acc': []},
+                     'test': {'loss': [], 'pvalue': [], 'cindex': [], 'surv_acc': [], 'grad_acc': []}}
+
+    for epoch in tqdm(range(opt.epoch_count, opt.niter + opt.niter_decay + 1)):
+
+        if opt.finetune == 1:
+            unfreeze_unimodal(opt, model, epoch)
+
+        model.train()
+        # Accumulated over all batches for the epoch-level C-Index and log-rank test.
+        risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([])
+        loss_epoch, grad_acc_epoch = 0, 0
+
+        for batch_idx, (x_path, x_grph, x_omic, censor, survtime, grade) in enumerate(train_loader):
+
+            censor = censor.to(device) if "surv" in opt.task else censor
+            grade = grade.to(device) if "grad" in opt.task else grade
+            _, pred = model(x_path=x_path.to(device), x_grph=x_grph.to(device), x_omic=x_omic.to(device))
+
+            # Weighted sum of the task losses: Cox partial likelihood for survival,
+            # NLL for grading (pred must be log-probabilities), plus regularization.
+            loss_cox = CoxLoss(survtime, censor, pred, device) if opt.task == "surv" else 0
+            loss_reg = define_reg(opt, model)
+            loss_nll = F.nll_loss(pred, grade) if opt.task == "grad" else 0
+            loss = opt.lambda_cox * loss_cox + opt.lambda_nll * loss_nll + opt.lambda_reg * loss_reg
+            loss_epoch += loss.item()
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if opt.task == "surv":
+                risk_pred_all = np.concatenate((risk_pred_all, pred.detach().cpu().numpy().reshape(-1)))
+                censor_all = np.concatenate((censor_all, censor.detach().cpu().numpy().reshape(-1)))
+                survtime_all = np.concatenate((survtime_all, survtime.detach().cpu().numpy().reshape(-1)))
+            elif opt.task == "grad":
+                pred = pred.argmax(dim=1, keepdim=True)
+                grad_acc_epoch += pred.eq(grade.view_as(pred)).sum().item()
+
+            if opt.verbose > 0 and opt.print_every > 0 and (batch_idx % opt.print_every == 0 or batch_idx + 1 == len(train_loader)):
+                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
+                    epoch + 1, opt.niter + opt.niter_decay, batch_idx + 1, len(train_loader), loss.item()))
+
+        scheduler.step()
+
+        if opt.measure or epoch == (opt.niter + opt.niter_decay - 1):
+            loss_epoch /= len(train_loader)
+
+            cindex_epoch = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
+            pvalue_epoch = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
+            surv_acc_epoch = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None
+            grad_acc_epoch = grad_acc_epoch / len(train_loader.dataset) if opt.task == 'grad' else None
+            loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test = test(opt, model, data, 'test', device)
+
+            metric_logger['train']['loss'].append(loss_epoch)
+            metric_logger['train']['cindex'].append(cindex_epoch)
+            metric_logger['train']['pvalue'].append(pvalue_epoch)
+            metric_logger['train']['surv_acc'].append(surv_acc_epoch)
+            metric_logger['train']['grad_acc'].append(grad_acc_epoch)
+
+            metric_logger['test']['loss'].append(loss_test)
+            metric_logger['test']['cindex'].append(cindex_test)
+            metric_logger['test']['pvalue'].append(pvalue_test)
+            metric_logger['test']['surv_acc'].append(surv_acc_test)
+            metric_logger['test']['grad_acc'].append(grad_acc_test)
+
+            pred_path = os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name,
+                                     '%s_%d%s%d_pred_test.pkl' % (opt.model_name, k, use_patch, epoch))
+            with open(pred_path, 'wb') as f:
+                pickle.dump(pred_test, f)
+
+            if opt.verbose > 0:
+                if opt.task == 'surv':
+                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format('Train', loss_epoch, 'C-Index', cindex_epoch))
+                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format('Test', loss_test, 'C-Index', cindex_test))
+                elif opt.task == 'grad':
+                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}'.format('Train', loss_epoch, 'Accuracy', grad_acc_epoch))
+                    print('[{:s}]\t\tLoss: {:.4f}, {:s}: {:.4f}\n'.format('Test', loss_test, 'Accuracy', grad_acc_test))
+
+        # Early stopping for the grading task once the epoch loss falls below opt.patience.
+        if opt.task == 'grad' and loss_epoch < opt.patience:
+            print("Early stopping at Epoch %d" % epoch)
+            break
+
+    return model, optimizer, metric_logger
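+
+
+# For reference only: a minimal sketch of the negative Cox partial log-likelihood
+# that utils.CoxLoss is assumed to compute. The helper name and interface below
+# are illustrative assumptions; this sketch is not called by train() or test().
+def _cox_loss_sketch(survtime, censor, hazard_pred, device):
+    # risk_set[i, j] = 1 when subject j is still at risk at subject i's event
+    # time, i.e. survtime[j] >= survtime[i].
+    survtime = survtime.reshape(-1)
+    risk_set = (survtime.reshape(1, -1) >= survtime.reshape(-1, 1)).float().to(device)
+    theta = hazard_pred.reshape(-1)
+    log_risk = torch.log(torch.sum(risk_set * torch.exp(theta).reshape(1, -1), dim=1))
+    # Only uncensored subjects (censor == 1) contribute events to the likelihood.
+    return -torch.mean((theta - log_risk) * censor)
+
+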
+def test(opt, model, data, split, device):
+    model.eval()
+
+    custom_data_loader = (PathgraphomicFastDatasetLoader(opt, data, split, mode=opt.mode)
+                          if opt.use_vgg_features
+                          else PathgraphomicDatasetLoader(opt, data, split=split, mode=opt.mode))
+    test_loader = DataLoader(dataset=custom_data_loader, batch_size=opt.batch_size,
+                             shuffle=False, collate_fn=mixed_collate)
+
+    risk_pred_all, censor_all, survtime_all = np.array([]), np.array([]), np.array([])
+    probs_all, gt_all = None, np.array([])
+    loss_test, grad_acc_test = 0, 0
+
+    for x_path, x_grph, x_omic, censor, survtime, grade in test_loader:
+
+        censor = censor.to(device) if "surv" in opt.task else censor
+        grade = grade.to(device) if "grad" in opt.task else grade
+        _, pred = model(x_path=x_path.to(device), x_grph=x_grph.to(device), x_omic=x_omic.to(device))
+
+        loss_cox = CoxLoss(survtime, censor, pred, device) if opt.task == "surv" else 0
+        loss_reg = define_reg(opt, model)
+        loss_nll = F.nll_loss(pred, grade) if opt.task == "grad" else 0
+        loss = opt.lambda_cox * loss_cox + opt.lambda_nll * loss_nll + opt.lambda_reg * loss_reg
+        loss_test += loss.item()
+
+        gt_all = np.concatenate((gt_all, grade.detach().cpu().numpy().reshape(-1)))
+
+        if opt.task == "surv":
+            risk_pred_all = np.concatenate((risk_pred_all, pred.detach().cpu().numpy().reshape(-1)))
+            censor_all = np.concatenate((censor_all, censor.detach().cpu().numpy().reshape(-1)))
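+    # For reference: CIndex_lifeline is assumed to wrap lifelines' concordance
+    # index, roughly as sketched below (illustrative only, not part of this PR):
+    #
+    #   from lifelines.utils import concordance_index
+    #   def CIndex_lifeline(hazards, censor, survtime):
+    #       # A higher predicted hazard should pair with a shorter survival time.
+    #       return concordance_index(survtime, -hazards, censor)
+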
+            survtime_all = np.concatenate((survtime_all, survtime.detach().cpu().numpy().reshape(-1)))
+        elif opt.task == "grad":
+            grade_pred = pred.argmax(dim=1, keepdim=True)
+            grad_acc_test += grade_pred.eq(grade.view_as(grade_pred)).sum().item()
+            probs_np = pred.detach().cpu().numpy()
+            probs_all = probs_np if probs_all is None else np.concatenate((probs_all, probs_np), axis=0)
+
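+    ###################################################
+    # ==== Measuring Test Loss, C-Index, P-Value ==== #
+    ###################################################
+    loss_test /= len(test_loader)
+    cindex_test = CIndex_lifeline(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
+    pvalue_test = cox_log_rank(risk_pred_all, censor_all, survtime_all) if opt.task == 'surv' else None
+    surv_acc_test = accuracy_cox(risk_pred_all, censor_all) if opt.task == 'surv' else None
+    grad_acc_test = grad_acc_test / len(test_loader.dataset) if opt.task == 'grad' else None
+    pred_test = [risk_pred_all, survtime_all, censor_all, probs_all, gt_all]
+
+    return loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test
+
+
+# Usage sketch (illustrative assumptions: an argparse-style `opt` from the
+# project's options module and a `data` dict holding the train/test splits):
+#
+#   device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+#   model, optimizer, metric_logger = train(opt, data, device, k=1)
+#   loss, cindex, pvalue, surv_acc, grad_acc, preds = test(opt, model, data, 'test', device)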