b/eval_surv.py
from __future__ import print_function

import argparse
import pdb
import os
import math
import sys

# internal imports
from utils.file_utils import save_pkl, load_pkl
from utils.utils import *
from utils.core_utils import train, eval_model
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
from datasets.dataset_survival import Generic_WSI_Survival_Dataset, Generic_MIL_Survival_Dataset

# pytorch imports
import torch
from torch.utils.data import DataLoader, sampler
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

from timeit import default_timer as timer

def main(args):
    # create results directory if necessary
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)

    if args.k_start == -1:
        start = 0
    else:
        start = args.k_start
    if args.k_end == -1:
        end = args.k
    else:
        end = args.k_end

    val_cindex = []
    folds = np.arange(start, end)

    for i in folds:
        fold_start = timer()
        seed_torch(args.seed)

        train_dataset, val_dataset = dataset.return_splits(from_id=False, csv_path='{}/splits_{}.csv'.format(args.split_dir, i))

        print('training: {}, validation: {}'.format(len(train_dataset), len(val_dataset)))
        datasets = (train_dataset, val_dataset)

        if 'omic' in args.mode:
            args.omic_input_dim = train_dataset.genomic_features.shape[1]
            print("Genomic Dimension", args.omic_input_dim)

        val_latest, cindex_latest = eval_model(datasets, i, args)
        val_cindex.append(cindex_latest)

        # write results to pkl
        save_pkl(os.path.join(args.results_dir, 'split_val_{}_results.pkl'.format(i)), val_latest)
        fold_end = timer()
        print('Fold %d Time: %f seconds' % (i, fold_end - fold_start))

    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    results_df = pd.DataFrame({'folds': folds, 'val_cindex': val_cindex})
    results_df.to_csv(os.path.join(args.results_dir, save_name))

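# For reference: eval_model returns, for each fold, a results dict (val_latest) and the
# validation concordance index (cindex_latest). A minimal sketch of how a censored c-index
# can be computed from per-patient risk scores, event times, and censoring labels is shown
# below; it is illustrative only, assumes scikit-survival is installed, and is not
# necessarily what utils.core_utils.eval_model does internally:
#
#   from sksurv.metrics import concordance_index_censored
#   cindex = concordance_index_censored(event_indicator.astype(bool), event_times, risk_scores)[0]
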
# Training settings
parser = argparse.ArgumentParser(description='Configurations for MMF Training')
parser.add_argument('--data_root_dir', type=str, default='/media/ssd1/pan-cancer', help='data directory')
parser.add_argument('--which_splits', type=str, default='5foldcv', help='path to splits directory')
parser.add_argument('--split_dir', type=str, help='set of splits to use for each cancer type')
parser.add_argument('--mode', type=str, default='omic')
parser.add_argument('--model_type', type=str, default='clam', help='type of model (attention_mil | max_net | mm_attention_mil)')

parser.add_argument('--max_epochs', type=int, default=20, help='maximum number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate (default: 2e-4)')
parser.add_argument('--label_frac', type=float, default=1.0, help='fraction of training labels (default: 1.0)')
parser.add_argument('--bag_weight', type=float, default=0.7, help='clam: weight coefficient for bag-level loss (default: 0.7)')
parser.add_argument('--reg', type=float, default=1e-5, help='weight decay (default: 1e-5)')
parser.add_argument('--seed', type=int, default=1, help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--k', type=int, default=5, help='number of folds (default: 5)')
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, first fold)')
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, last fold)')
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
parser.add_argument('--log_data', action='store_true', default=True, help='log data using tensorboard')
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
parser.add_argument('--opt', type=str, choices=['adam', 'sgd'], default='adam')
parser.add_argument('--drop_out', action='store_true', default=True, help='enable dropout (p=0.25)')
parser.add_argument('--inst_loss', type=str, choices=['svm', 'ce', None], default=None, help='instance-level clustering loss function (default: None)')
parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce', 'ce_surv', 'nll_surv', 'cox_surv'], default='nll_surv', help='slide-level classification loss function (default: nll_surv)')
parser.add_argument('--alpha_surv', type=float, default=0.0, help='weight given to uncensored patients in the survival loss (default: 0.0)')
parser.add_argument('--reg_type', type=str, choices=['None', 'omic', 'pathomic'], default='None', help='regularization type (default: None)')
parser.add_argument('--lambda_reg', type=float, default=1e-4, help='regularization strength')
parser.add_argument('--weighted_sample', action='store_true', default=True, help='enable weighted sampling')
parser.add_argument('--model_size_wsi', type=str, default='small', help='size of AMIL model')
parser.add_argument('--model_size_omic', type=str, default='small', help='size of SNN model')
parser.add_argument('--gc', type=int, default=1, help='gradient accumulation step')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--gate_path', action='store_true', default=False, help='enable feature gating for the WSI branch in the MMF layer')
parser.add_argument('--gate_omic', action='store_true', default=False, help='enable feature gating for the omic branch in the MMF layer')
parser.add_argument('--fusion', type=str, default='tensor', help='which fusion mechanism to use')
parser.add_argument('--overwrite', action='store_true', default=False, help='overwrite results if the experiment already exists')
parser.add_argument('--apply_mad', action='store_true', default=True, help='use genes filtered by median absolute deviation')
parser.add_argument('--task', type=str, default='survival', help='which task')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

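# Example invocation (illustrative only; the data root and split directory below are
# placeholders and must point at your own extracted features and split CSVs):
#
#   python eval_surv.py --data_root_dir /path/to/pan-cancer-features \
#       --which_splits 5foldcv --split_dir tcga_luad \
#       --mode omic --model_type max_net --bag_loss nll_surv --task survival
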
### Creates Custom Experiment Code
# NOTE: --split_dir must be supplied, since the experiment code is derived from it.
exp_code = '_'.join(args.split_dir.split('_')[:2])
dataset_path = 'dataset_csv'
param_code = ''

if args.model_type == 'attention_mil':
    param_code += 'WSI'
elif args.model_type == 'max_net':
    param_code += 'SNN'
elif args.model_type == 'mm_attention_mil' and args.fusion == 'tensor':
    param_code += 'MMF'
else:
    raise NotImplementedError

if 'small' in args.model_size_omic:
    param_code += 'sm'

param_code += '_%s' % args.bag_loss

if 'mm_' in args.model_type:
    param_code += '_g'
    if args.gate_path:
        param_code += '1'
    else:
        param_code += '0'

    if args.gate_omic:
        param_code += '1'
    else:
        param_code += '0'

param_code += '_a%s' % str(args.alpha_surv)

if args.lr != 2e-4:
    param_code += '_lr%s' % format(args.lr, '.0e')

if args.reg_type != 'None':
    param_code += '_reg%s' % format(args.lambda_reg, '.0e')

param_code += '_%s' % args.which_splits.split("_")[0]

if args.gc != 1:
    param_code += '_gc%s' % str(args.gc)

if args.apply_mad:
    param_code += '_mad'
    #dataset_path += '_mad'

args.exp_code = exp_code + "_" + param_code

### task
if args.task == 'survival':
    args.task = '_'.join(args.split_dir.split('_')[:2]) + '_survival'
print("Experiment Name:", exp_code)

def seed_torch(seed=7):
    import random
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch(args.seed)

encoding_size = 1024
settings = {'num_splits': args.k,
            'k_start': args.k_start,
            'k_end': args.k_end,
            'task': args.task,
            'max_epochs': args.max_epochs,
            'results_dir': args.results_dir,
            'lr': args.lr,
            'experiment': args.exp_code,
            'reg': args.reg,
            'label_frac': args.label_frac,
            'inst_loss': args.inst_loss,
            'bag_loss': args.bag_loss,
            'bag_weight': args.bag_weight,
            'seed': args.seed,
            'model_type': args.model_type,
            'model_size_wsi': args.model_size_wsi,
            'model_size_omic': args.model_size_omic,
            'use_drop_out': args.drop_out,
            'weighted_sample': args.weighted_sample,
            'gc': args.gc,
            'opt': args.opt}

print('\nLoad Dataset')

# Map each survival task to the feature directory (or oncotree-code -> directory dict for
# multi-cohort studies) that holds its extracted 20x WSI features.
study_data_dirs = {
    'tcga_blca_survival':     'tcga_bladder_20x_features',
    'tcga_brca_survival':     'tcga_breast_20x_features',
    'tcga_coadread_survival': 'tcga_coadread_20x_features',
    'tcga_gbmlgg_survival':   {'ASTR':  'tcga_lgg_20x_features',
                               'AASTR': 'tcga_lgg_20x_features',
                               'ODG':   'tcga_lgg_20x_features',
                               'OAST':  'tcga_lgg_20x_features',
                               'AOAST': 'tcga_lgg_20x_features',
                               'GBM':   'tcga_gbm_20x_features'},
    'tcga_hnsc_survival':     'tcga_hnsc_20x_features',
    'tcga_kirc_survival':     'tcga_kidney_20x_features',
    'tcga_kirp_survival':     'tcga_kidney_20x_features',
    'tcga_lihc_survival':     'tcga_liver_20x_features',
    'tcga_luad_survival':     'tcga_lung_20x_features',
    'tcga_lusc_survival':     'tcga_lung_20x_features',
    'tcga_paad_survival':     'tcga_pancreas_20x_features',
    'tcga_skcm_survival':     'tcga_skin_20x_features',
    'tcga_stad_survival':     'tcga_stomach_20x_features',
    'tcga_ucec_survival':     'tcga_endometrial_20x_features',
}

if args.task in study_data_dirs:
    args.n_classes = 4
    proj = '_'.join(args.task.split('_')[:2])
    feature_dirs = study_data_dirs[args.task]
    if isinstance(feature_dirs, dict):
        data_dir = {code: os.path.join(args.data_root_dir, d) for code, d in feature_dirs.items()}
    else:
        data_dir = os.path.join(args.data_root_dir, feature_dirs)
    dataset = Generic_MIL_Survival_Dataset(csv_path='./%s/%s_all.csv' % (dataset_path, proj),
                                           mode=args.mode,
                                           data_dir=data_dir,
                                           shuffle=False,
                                           seed=args.seed,
                                           print_info=True,
                                           patient_strat=False,
                                           n_bins=4,
                                           label_col='survival_months',
                                           ignore=[])
else:
    raise NotImplementedError

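# Illustrative layout assumed under --data_root_dir (one feature folder per study, matching
# the directory names in study_data_dirs above; the exact per-slide file format is whatever
# Generic_MIL_Survival_Dataset expects from the feature-extraction pipeline):
#
#   <data_root_dir>/
#       tcga_lung_20x_features/       # shared by tcga_luad_survival and tcga_lusc_survival
#       tcga_kidney_20x_features/     # shared by tcga_kirc_survival and tcga_kirp_survival
#       tcga_lgg_20x_features/        # LGG oncotree codes for tcga_gbmlgg_survival
#       ...
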
if isinstance(dataset, Generic_MIL_Survival_Dataset):
    args.task_type = 'survival'
else:
    raise NotImplementedError

if not os.path.isdir(args.results_dir):
    os.mkdir(args.results_dir)

### Drop args.which_splits from the path below if you want a less cluttered results folder.
args.results_dir = os.path.join(args.results_dir, args.which_splits, param_code, str(args.exp_code) + '_s{}'.format(args.seed))
if not os.path.isdir(args.results_dir):
    os.makedirs(args.results_dir)

if ('summary.csv' in os.listdir(args.results_dir)) and (not args.overwrite):
    print("Exp Code <%s> already exists! Exiting script." % args.exp_code)
    sys.exit()

if args.split_dir is None:
    args.split_dir = os.path.join('./splits', args.task + '_{}'.format(int(args.label_frac * 100)))
else:
    args.split_dir = os.path.join('./splits', args.which_splits, args.split_dir)

print("split_dir", args.split_dir)

assert os.path.isdir(args.split_dir)

settings.update({'split_dir': args.split_dir})


with open(os.path.join(args.results_dir, 'experiment_{}.txt'.format(args.exp_code)), 'w') as f:
    print(settings, file=f)

print("################# Settings ###################")
for key, val in settings.items():
    print("{}: {}".format(key, val))

if __name__ == "__main__":
    start = timer()
    results = main(args)
    end = timer()
    print("finished!")
    print("end script")
    print('Script Time: %f seconds' % (end - start))