Diff of /train_cv.py [000000] .. [2095ed]

Switch to unified view

a b/train_cv.py
1
import os
2
import logging
3
import numpy as np
4
import random
5
import pickle
6
7
import torch
8
9
# Env
10
from data_loaders import *
11
from options import parse_args
12
from train_test import train, test
13
14
15
### 1. Initializes parser and device
16
opt = parse_args()
17
device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
18
print("Using device:", device)
19
if not os.path.exists(opt.checkpoints_dir): os.makedirs(opt.checkpoints_dir)
20
if not os.path.exists(os.path.join(opt.checkpoints_dir, opt.exp_name)): os.makedirs(os.path.join(opt.checkpoints_dir, opt.exp_name))
21
if not os.path.exists(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name)): os.makedirs(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name))
22
23
### 2. Initializes Data
24
ignore_missing_histype = 1 if 'grad' in opt.task else 0
25
ignore_missing_moltype = 1 if 'omic' in opt.mode else 0
26
use_patch, roi_dir = ('_patch_', 'all_st_patches_512') if opt.use_vgg_features else ('_', 'all_st')
27
use_rnaseq = '_rnaseq' if opt.use_rnaseq else ''
28
29
data_cv_path = '%s/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (opt.dataroot, roi_dir, ignore_missing_moltype, ignore_missing_histype, opt.use_vgg_features, use_rnaseq)
30
print("Loading %s" % data_cv_path)
31
data_cv = pickle.load(open(data_cv_path, 'rb'))
32
data_cv_splits = data_cv['cv_splits']
33
results = []
34
35
### 3. Sets-Up Main Loop
36
for k, data in data_cv_splits.items():
37
    print("*******************************************")
38
    print("************** SPLIT (%d/%d) **************" % (k, len(data_cv_splits.items())))
39
    print("*******************************************")
40
    if os.path.exists(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d_patch_pred_train.pkl' % (opt.model_name, k))):
41
        print("Train-Test Split already made.")
42
        continue
43
44
    ### 3.1 Trains Model
45
    model, optimizer, metric_logger = train(opt, data, device, k)
46
47
    ### 3.2 Evalutes Train + Test Error, and Saves Model
48
    loss_train, cindex_train, pvalue_train, surv_acc_train, grad_acc_train, pred_train = test(opt, model, data, 'train', device)
49
    loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test = test(opt, model, data, 'test', device)
50
51
    if opt.task == 'surv':
52
        print("[Final] Apply model to training set: C-Index: %.10f, P-Value: %.10e" % (cindex_train, pvalue_train))
53
        logging.info("[Final] Apply model to training set: C-Index: %.10f, P-Value: %.10e" % (cindex_train, pvalue_train))
54
        print("[Final] Apply model to testing set: C-Index: %.10f, P-Value: %.10e" % (cindex_test, pvalue_test))
55
        logging.info("[Final] Apply model to testing set: cC-Index: %.10f, P-Value: %.10e" % (cindex_test, pvalue_test))
56
        results.append(cindex_test)
57
    elif opt.task == 'grad':
58
        print("[Final] Apply model to training set: Loss: %.10f, Acc: %.4f" % (loss_train, grad_acc_train))
59
        logging.info("[Final] Apply model to training set: Loss: %.10f, Acc: %.4f" % (loss_train, grad_acc_train))
60
        print("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test))
61
        logging.info("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test))
62
        results.append(grad_acc_test)
63
64
    ### 3.3 Saves Model
65
    if len(opt.gpu_ids) > 0 and torch.cuda.is_available():
66
        model_state_dict = model.module.cpu().state_dict()
67
    else:
68
        model_state_dict = model.cpu().state_dict()
69
    
70
    torch.save({
71
        'split':k,
72
        'opt': opt,
73
        'epoch': opt.niter+opt.niter_decay,
74
        'data': data,
75
        'model_state_dict': model_state_dict,
76
        'optimizer_state_dict': optimizer.state_dict(),
77
        'metrics': metric_logger}, 
78
        os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d.pt' % (opt.model_name, k)))
79
80
    print()
81
82
    pickle.dump(pred_train, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d%spred_train.pkl' % (opt.model_name, k, use_patch)), 'wb'))
83
    pickle.dump(pred_test, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d%spred_test.pkl' % (opt.model_name, k, use_patch)), 'wb'))
84
85
86
print('Split Results:', results)
87
print("Average:", np.array(results).mean())
88
pickle.dump(results, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_results.pkl' % opt.model_name), 'wb'))