|
a |
|
b/main_mtl_concat.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
import argparse |
|
|
4 |
import pdb |
|
|
5 |
import os |
|
|
6 |
import math |
|
|
7 |
|
|
|
8 |
# internal imports |
|
|
9 |
from utils.file_utils import save_pkl, load_pkl |
|
|
10 |
from utils.utils import * |
|
|
11 |
from utils.core_utils_mtl_concat import train |
|
|
12 |
from datasets.dataset_mtl_concat import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset |
|
|
13 |
|
|
|
14 |
# pytorch imports |
|
|
15 |
import torch |
|
|
16 |
from torch.utils.data import DataLoader, sampler |
|
|
17 |
import torch.nn as nn |
|
|
18 |
import torch.nn.functional as F |
|
|
19 |
|
|
|
20 |
import pandas as pd |
|
|
21 |
import numpy as np |
|
|
22 |
|
|
|
23 |
def main(args):
    """Train one model per fold and write per-fold results plus a summary CSV.

    Relies on the module-level ``dataset`` object and ``seed_torch`` helper
    being defined before it is called.

    Args:
        args: parsed command-line namespace. Reads ``results_dir``, ``k``,
            ``k_start``, ``k_end``, ``seed`` and ``split_dir``; the whole
            namespace is also forwarded to ``train``.

    Side effects:
        * creates ``args.results_dir`` if it does not exist;
        * writes ``split_<fold>_results.pkl`` for every fold;
        * writes ``summary.csv`` when the processed fold count equals
          ``args.k``, otherwise ``summary_partial_<start>_<end>.csv``.
    """
    # makedirs(exist_ok=True) instead of isdir()+mkdir(): tolerates nested
    # paths and another job creating the same directory concurrently.
    os.makedirs(args.results_dir, exist_ok=True)

    # -1 is the "not specified" sentinel: start at fold 0 / stop at fold k.
    start = 0 if args.k_start == -1 else args.k_start
    end = args.k if args.k_end == -1 else args.k_end
    folds = np.arange(start, end)

    # Column order of the summary file; it mirrors the order in which
    # ``train`` returns the metrics after the results object.
    metric_names = (
        'cls_test_auc', 'cls_val_auc', 'cls_test_acc', 'cls_val_acc',
        'site_test_auc', 'site_val_auc', 'site_test_acc', 'site_val_acc',
    )
    metrics = {name: [] for name in metric_names}

    for i in folds:
        # Re-seed before every fold so each fold starts from the same RNG state.
        seed_torch(args.seed)
        train_dataset, val_dataset, test_dataset = dataset.return_splits(
            from_id=False,
            csv_path='{}/splits_{}.csv'.format(args.split_dir, i))

        print('training: {}, validation: {}, testing: {}'.format(
            len(train_dataset), len(val_dataset), len(test_dataset)))
        datasets = (train_dataset, val_dataset, test_dataset)

        # Explicit nine-way unpacking keeps the arity check on train()'s return.
        (results,
         cls_test_auc, cls_val_auc, cls_test_acc, cls_val_acc,
         site_test_auc, site_val_auc, site_test_acc, site_val_acc) = train(datasets, i, args)

        fold_values = (
            cls_test_auc, cls_val_auc, cls_test_acc, cls_val_acc,
            site_test_auc, site_val_auc, site_test_acc, site_val_acc,
        )
        for name, value in zip(metric_names, fold_values):
            metrics[name].append(value)

        # write results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        save_pkl(filename, results)

    summary = {'folds': folds}
    summary.update(metrics)
    final_df = pd.DataFrame(summary)

    # A run that did not cover all k folds is saved under a distinct name.
    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.results_dir, save_name))
|
|
79 |
|
|
|
80 |
# Training settings
# Command-line interface of the training script. Every option is optional;
# options without an explicit default come back as None.
parser = argparse.ArgumentParser(description='Configurations for WSI Training')
parser.add_argument('--data_root_dir', type=str, help='data directory')
parser.add_argument('--max_epochs', type=int, default=200,
                    help='maximum number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate (default: 0.0001)')
parser.add_argument('--reg', type=float, default=1e-5,
                    help='weight decay (default: 1e-5)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--k', type=int, default=10, help='number of folds (default: 10)')
# -1 is a sentinel: main() starts at fold 0 and stops at fold k when these are
# left at their defaults, so the help text must say first/last respectively.
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, first fold)')
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, last fold)')
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
parser.add_argument('--split_dir', type=str, default=None,
                    help='manually specify the set of splits to use, '
                    + 'instead of infering from the task and label_frac argument (default: None)')
parser.add_argument('--log_data', action='store_true', default=False, help='log data using tensorboard')
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
parser.add_argument('--opt', type=str, choices=['adam', 'sgd'], default='adam')
parser.add_argument('--drop_out', action='store_true', default=False, help='enabel dropout (p=0.25)')
parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
parser.add_argument('--weighted_sample', action='store_true', default=False, help='enable weighted sampling')
parser.add_argument('--task', type=str, choices=['dummy_mtl_concat'])
|
|
106 |
# Parse the command line once at module load; the rest of the script reads
# its configuration from this namespace.
args = parser.parse_args()

# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
108 |
|
|
|
109 |
def seed_torch(seed=7):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets the ``PYTHONHASHSEED`` environment variable and switches cuDNN
    to deterministic mode with benchmarking turned off.

    Args:
        seed: integer seed applied to every generator (default: 7).
    """
    import random

    random.seed(seed)
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it
    # here is only picked up by processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Query CUDA availability directly rather than through a module-level
    # ``device`` global, so the helper works wherever it is imported or called.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
120 |
|
|
|
121 |
# Seed all RNGs before any dataset object is built.
seed_torch(args.seed)

encoding_size = 1024

# Snapshot of the run configuration; written to the experiment text file and
# echoed to stdout below. 'results_dir' is recorded as passed on the command
# line, i.e. before the experiment sub-directory is appended.
settings = {'num_splits': args.k,
            'k_start': args.k_start,
            'k_end': args.k_end,
            'task': args.task,
            'max_epochs': args.max_epochs,
            'results_dir': args.results_dir,
            'lr': args.lr,
            'experiment': args.exp_code,
            'reg': args.reg,
            'seed': args.seed,
            "use_drop_out": args.drop_out,
            'weighted_sample': args.weighted_sample,
            'opt': args.opt}

print('\nLoad Dataset')

if args.task == 'dummy_mtl_concat':
    args.n_classes = 18
    # Three label columns, each with its own string -> integer mapping.
    dataset = Generic_MIL_MTL_Dataset(
        csv_path='dataset_csv/dummy_dataset.csv',
        data_dir=os.path.join(args.data_root_dir, 'DUMMY_DATA_DIR'),
        shuffle=False,
        seed=args.seed,
        print_info=True,
        label_dicts=[{'Lung': 0, 'Breast': 1, 'Colorectal': 2, 'Ovarian': 3,
                      'Pancreatobiliary': 4, 'Adrenal': 5,
                      'Skin': 6, 'Prostate': 7, 'Renal': 8, 'Bladder': 9,
                      'Esophagogastric': 10, 'Thyroid': 11,
                      'Head Neck': 12, 'Glioma': 13,
                      'Germ Cell': 14, 'Endometrial': 15,
                      'Cervix': 16, 'Liver': 17},
                     {'Primary': 0, 'Metastatic': 1},
                     {'F': 0, 'M': 1}],
        label_cols=['label', 'site', 'sex'],
        patient_strat=False)
else:
    raise NotImplementedError

# makedirs(exist_ok=True) replaces the isdir()+mkdir() pairs: it is safe when
# several jobs start at once and also creates missing parent directories.
os.makedirs(args.results_dir, exist_ok=True)

# Each experiment/seed combination gets its own sub-directory.
args.results_dir = os.path.join(args.results_dir, str(args.exp_code) + '_s{}'.format(args.seed))
os.makedirs(args.results_dir, exist_ok=True)

# Split files are always looked up under ./splits.
if args.split_dir is None:
    args.split_dir = os.path.join('splits', args.task + '_{}'.format(int(100)))
else:
    args.split_dir = os.path.join('splits', args.split_dir)
# Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
if not os.path.isdir(args.split_dir):
    raise FileNotFoundError('split directory not found: {}'.format(args.split_dir))

settings.update({'split_dir': args.split_dir})

# The ``with`` block closes the file; no explicit close() is needed.
with open(os.path.join(args.results_dir, 'experiment_{}.txt'.format(args.exp_code)), 'w') as f:
    print(settings, file=f)

print("################# Settings ###################")
for key, val in settings.items():
    print("{}: {}".format(key, val))

if __name__ == "__main__":
    results = main(args)
    print("finished!")
    print("end script")
|
|
190 |
|
|
|
191 |
|