|
a |
|
b/train_cv.py |
|
|
1 |
import os |
|
|
2 |
import logging |
|
|
3 |
import numpy as np |
|
|
4 |
import random |
|
|
5 |
import pickle |
|
|
6 |
|
|
|
7 |
import torch |
|
|
8 |
|
|
|
9 |
# Env |
|
|
10 |
from data_loaders import * |
|
|
11 |
from options import parse_args |
|
|
12 |
from train_test import train, test |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
### 1. Initializes parser and device |
|
|
16 |
opt = parse_args() |
|
|
17 |
device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') |
|
|
18 |
print("Using device:", device) |
|
|
19 |
if not os.path.exists(opt.checkpoints_dir): os.makedirs(opt.checkpoints_dir) |
|
|
20 |
if not os.path.exists(os.path.join(opt.checkpoints_dir, opt.exp_name)): os.makedirs(os.path.join(opt.checkpoints_dir, opt.exp_name)) |
|
|
21 |
if not os.path.exists(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name)): os.makedirs(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name)) |
|
|
22 |
|
|
|
23 |
### 2. Initializes Data |
|
|
24 |
ignore_missing_histype = 1 if 'grad' in opt.task else 0 |
|
|
25 |
ignore_missing_moltype = 1 if 'omic' in opt.mode else 0 |
|
|
26 |
use_patch, roi_dir = ('_patch_', 'all_st_patches_512') if opt.use_vgg_features else ('_', 'all_st') |
|
|
27 |
use_rnaseq = '_rnaseq' if opt.use_rnaseq else '' |
|
|
28 |
|
|
|
29 |
data_cv_path = '%s/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (opt.dataroot, roi_dir, ignore_missing_moltype, ignore_missing_histype, opt.use_vgg_features, use_rnaseq) |
|
|
30 |
print("Loading %s" % data_cv_path) |
|
|
31 |
data_cv = pickle.load(open(data_cv_path, 'rb')) |
|
|
32 |
data_cv_splits = data_cv['cv_splits'] |
|
|
33 |
results = [] |
|
|
34 |
|
|
|
35 |
### 3. Sets-Up Main Loop |
|
|
36 |
for k, data in data_cv_splits.items(): |
|
|
37 |
print("*******************************************") |
|
|
38 |
print("************** SPLIT (%d/%d) **************" % (k, len(data_cv_splits.items()))) |
|
|
39 |
print("*******************************************") |
|
|
40 |
if os.path.exists(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d_patch_pred_train.pkl' % (opt.model_name, k))): |
|
|
41 |
print("Train-Test Split already made.") |
|
|
42 |
continue |
|
|
43 |
|
|
|
44 |
### 3.1 Trains Model |
|
|
45 |
model, optimizer, metric_logger = train(opt, data, device, k) |
|
|
46 |
|
|
|
47 |
### 3.2 Evalutes Train + Test Error, and Saves Model |
|
|
48 |
loss_train, cindex_train, pvalue_train, surv_acc_train, grad_acc_train, pred_train = test(opt, model, data, 'train', device) |
|
|
49 |
loss_test, cindex_test, pvalue_test, surv_acc_test, grad_acc_test, pred_test = test(opt, model, data, 'test', device) |
|
|
50 |
|
|
|
51 |
if opt.task == 'surv': |
|
|
52 |
print("[Final] Apply model to training set: C-Index: %.10f, P-Value: %.10e" % (cindex_train, pvalue_train)) |
|
|
53 |
logging.info("[Final] Apply model to training set: C-Index: %.10f, P-Value: %.10e" % (cindex_train, pvalue_train)) |
|
|
54 |
print("[Final] Apply model to testing set: C-Index: %.10f, P-Value: %.10e" % (cindex_test, pvalue_test)) |
|
|
55 |
logging.info("[Final] Apply model to testing set: cC-Index: %.10f, P-Value: %.10e" % (cindex_test, pvalue_test)) |
|
|
56 |
results.append(cindex_test) |
|
|
57 |
elif opt.task == 'grad': |
|
|
58 |
print("[Final] Apply model to training set: Loss: %.10f, Acc: %.4f" % (loss_train, grad_acc_train)) |
|
|
59 |
logging.info("[Final] Apply model to training set: Loss: %.10f, Acc: %.4f" % (loss_train, grad_acc_train)) |
|
|
60 |
print("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test)) |
|
|
61 |
logging.info("[Final] Apply model to testing set: Loss: %.10f, Acc: %.4f" % (loss_test, grad_acc_test)) |
|
|
62 |
results.append(grad_acc_test) |
|
|
63 |
|
|
|
64 |
### 3.3 Saves Model |
|
|
65 |
if len(opt.gpu_ids) > 0 and torch.cuda.is_available(): |
|
|
66 |
model_state_dict = model.module.cpu().state_dict() |
|
|
67 |
else: |
|
|
68 |
model_state_dict = model.cpu().state_dict() |
|
|
69 |
|
|
|
70 |
torch.save({ |
|
|
71 |
'split':k, |
|
|
72 |
'opt': opt, |
|
|
73 |
'epoch': opt.niter+opt.niter_decay, |
|
|
74 |
'data': data, |
|
|
75 |
'model_state_dict': model_state_dict, |
|
|
76 |
'optimizer_state_dict': optimizer.state_dict(), |
|
|
77 |
'metrics': metric_logger}, |
|
|
78 |
os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d.pt' % (opt.model_name, k))) |
|
|
79 |
|
|
|
80 |
print() |
|
|
81 |
|
|
|
82 |
pickle.dump(pred_train, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d%spred_train.pkl' % (opt.model_name, k, use_patch)), 'wb')) |
|
|
83 |
pickle.dump(pred_test, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%d%spred_test.pkl' % (opt.model_name, k, use_patch)), 'wb')) |
|
|
84 |
|
|
|
85 |
|
|
|
86 |
print('Split Results:', results) |
|
|
87 |
print("Average:", np.array(results).mean()) |
|
|
88 |
pickle.dump(results, open(os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_results.pkl' % opt.model_name), 'wb')) |