[286bfb]: / src / training / main_survival.py

Download this file

290 lines (244 with data), 12.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""
Main entry point for survival downstream tasks
"""
from __future__ import print_function
import argparse
import pdb
import os
from os.path import join as j_
import sys
# internal imports
from utils.file_utils import save_pkl
from utils.utils import (seed_torch, array2list, merge_dict, read_splits,
parse_model_name, get_current_time, extract_patching_info)
from .trainer import train
from wsi_datasets import WSISurvivalDataset
# pytorch imports
import torch
from torch.utils.data import DataLoader
import pandas as pd
import json
PROTO_MODELS = ['PANTHER', 'OT', 'H2T', 'ProtoCount']


def build_datasets(csv_splits, model_type, batch_size=1, num_workers=2,
                   train_kwargs=None, val_kwargs=None, loss_fn=None):
    """
    Construct one DataLoader per data split.

    Splits are built in the iteration order of ``csv_splits``. When the loss
    is 'nll', the label bins obtained from the 'train' dataset are passed to
    every split built after it, so 'train' has to come first for the other
    splits to reuse its bins.

    Args:
        csv_splits: mapping from split name (e.g. 'train', 'val', 'test') to
            the per-split table handed to WSISurvivalDataset.
        model_type: model name; models listed in PROTO_MODELS are always
            batched with ``batch_size``.
        batch_size: requested batch size.
        num_workers: number of DataLoader worker processes.
        train_kwargs: dataset keyword arguments for the 'train' split. Must
            contain a 'shuffle' entry; may contain 'bag_size'.
        val_kwargs: dataset keyword arguments for every other split (same
            requirements as ``train_kwargs``).
        loss_fn: name of the loss function. If None, the module-level
            ``args.loss_fn`` is used (the previous behaviour).

    Returns:
        dict mapping each split name to its DataLoader.
    """
    # None instead of `{}` defaults: avoids sharing one dict across calls.
    train_kwargs = {} if train_kwargs is None else train_kwargs
    val_kwargs = {} if val_kwargs is None else val_kwargs
    if loss_fn is None:
        loss_fn = args.loss_fn

    dataset_splits = {}
    label_bins = None
    for k, df in csv_splits.items():  # ['train', 'val', 'test']
        dataset_kwargs = (train_kwargs if k == 'train' else val_kwargs).copy()
        dataset_kwargs['label_bins'] = label_bins
        dataset = WSISurvivalDataset(df, **dataset_kwargs)

        # If prototype methods, each WSI will have same feature bag dimension and is batchable.
        # Otherwise, we need a batch size of 1 to accommodate the different bag size of each WSI,
        # unless a fixed number of patch features is sampled per WSI (bag_size > 0).
        # The decision is made per split in a separate variable: overwriting
        # `batch_size` here would force every later split to 1 as well.
        split_batch_size = batch_size
        if model_type not in PROTO_MODELS and not (dataset_kwargs.get('bag_size', -1) > 0):
            split_batch_size = 1

        dataset_splits[k] = DataLoader(dataset,
                                       batch_size=split_batch_size,
                                       shuffle=dataset_kwargs['shuffle'],
                                       num_workers=num_workers)
        print(f'split: {k}, n: {len(dataset)}')

        if loss_fn == 'nll' and k == 'train':
            # Reuse the training bins for the splits that follow.
            label_bins = dataset.get_label_bins()

    return dataset_splits
def main(args):
    """
    Run a survival training job for the parsed settings and save its results.

    Mutates ``args`` in place: a train/val bag size of -1 is replaced by
    ``args.bag_size`` and ``args.n_label_bins`` is set to 0 for any loss other
    than 'nll'. Writes per-split result pickles, 'summary.csv',
    'summary.csv.json' and 'all_dumps.h5' into ``args.results_dir``.

    Returns:
        dict mapping '<metric>_<split>' to the list-converted metric values.
    """
    # A split-specific bag size of -1 means "inherit the shared bag size".
    for size_attr in ('train_bag_size', 'val_bag_size'):
        if getattr(args, size_attr) == -1:
            setattr(args, size_attr, args.bag_size)

    if args.loss_fn != 'nll':
        args.n_label_bins = 0

    # Censorship column name is derived from the prefix of the target column.
    target_prefix = args.target_col.split('_')[0]
    censorship_col = target_prefix + '_censorship'

    shared_kwargs = {
        'data_source': args.data_source,
        'survival_time_col': args.target_col,
        'censorship_col': censorship_col,
        'n_label_bins': args.n_label_bins,
        'label_bins': None,
    }
    train_kwargs = {**shared_kwargs, 'bag_size': args.train_bag_size, 'shuffle': True}
    # Non-train splits get their own bag size and are never shuffled.
    val_kwargs = {**shared_kwargs, 'bag_size': args.val_bag_size, 'shuffle': False}

    # Seed, load the splits and wrap them in dataloaders.
    seed_torch(args.seed)
    csv_splits = read_splits(args)
    print('successfully read splits for: ', list(csv_splits.keys()))
    dataset_splits = build_datasets(
        csv_splits,
        model_type=args.model_type,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        train_kwargs=train_kwargs,
        val_kwargs=val_kwargs,
    )

    fold_results, fold_dumps = train(dataset_splits, args, mode='survival')

    # Merge the metrics per split and dump each split's raw outputs.
    all_results = {}
    for split, split_results in fold_results.items():
        merged_so_far = all_results[split] if split in all_results else {}
        all_results[split] = merge_dict(merged_so_far, split_results)
        split_dump_path = j_(args.results_dir, f'{split}_results.pkl')
        save_pkl(split_dump_path, fold_dumps[split])

    # Flatten to '<metric>_<split>' keys for the summary files.
    final_dict = {}
    for split, split_results in all_results.items():
        for metric, val in split_results.items():
            final_dict[f'{metric}_{split}'] = array2list(val)

    summary_path = j_(args.results_dir, 'summary.csv')
    pd.DataFrame(final_dict).to_csv(summary_path, index=False)
    with open(summary_path + '.json', 'w') as f:
        f.write(json.dumps(final_dict, sort_keys=True, indent=4))

    save_pkl(j_(args.results_dir, 'all_dumps.h5'), fold_dumps)

    return final_dict
# Command-line interface for the survival training script
parser = argparse.ArgumentParser(description='Configurations for WSI Training')

# --- optimisation ---
parser.add_argument('--max_epochs', type=int, default=20,
                    help='maximum number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--wd', type=float, default=1e-5, help='weight decay')
parser.add_argument('--accum_steps', type=int, default=1, help='grad accumulation steps')
parser.add_argument('--opt', type=str, default='adamW', choices=['adamW', 'sgd', 'RAdam'])
parser.add_argument('--lr_scheduler', type=str, default='constant',
                    choices=['cosine', 'linear', 'constant'])
parser.add_argument('--warmup_steps', type=int, default=-1, help='warmup iterations')
parser.add_argument('--warmup_epochs', type=int, default=-1, help='warmup epochs')
parser.add_argument('--batch_size', type=int, default=1)

# --- miscellaneous ---
parser.add_argument('--print_every', type=int, default=100, help='how often to print')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--num_workers', type=int, default=2)

# --- early stopping ---
parser.add_argument('--early_stopping', type=int, default=0, help='enable early stopping')
parser.add_argument('--es_min_epochs', type=int, default=3, help='early stopping min epochs')
parser.add_argument('--es_patience', type=int, default=5, help='early stopping min patience')
parser.add_argument('--es_metric', type=str, default='loss', help='early stopping metric')

# --- model ---
parser.add_argument('--model_type', type=str, default='ABMIL',
                    choices=['H2T', 'ABMIL', 'TransMIL', 'SumMIL', 'OT',
                             'PANTHER', 'ProtoCount', 'DeepAttnMIL', 'ILRA'],
                    help='type of model')
parser.add_argument('--emb_model_type', type=str, default='LinEmb_LR')
parser.add_argument('--ot_eps', type=float, default=0.1,
                    help='Strength for entropic constraint regularization for OT')
parser.add_argument('--model_config', type=str, default='ABMIL_default',
                    help="name of model config file")
parser.add_argument('--n_fc_layers', type=int)
parser.add_argument('--em_iter', type=int)
parser.add_argument('--tau', type=float)
parser.add_argument('--out_type', type=str, default='param_cat')

# --- prototypes, input features, bag sizes and loss ---
parser.add_argument('--load_proto', action='store_true', default=False)
parser.add_argument('--proto_path', type=str)
parser.add_argument('--fix_proto', action='store_true', default=False)
parser.add_argument('--n_proto', type=int)
parser.add_argument('--in_dim', type=int, default=768, help='dim of input features')
parser.add_argument('--in_dropout', type=float, default=0.1,
                    help='Probability of dropping out input features.')
parser.add_argument('--bag_size', type=int, default=-1)
parser.add_argument('--train_bag_size', type=int, default=-1)
parser.add_argument('--val_bag_size', type=int, default=-1)
parser.add_argument('--loss_fn', type=str, default='nll',
                    choices=['nll', 'cox', 'sumo', 'ipcwls', 'rank'],
                    help='which loss function to use')
parser.add_argument('--nll_alpha', type=float, default=0,
                    help='Balance between censored / uncensored loss')

# --- experiment task / labels ---
parser.add_argument('--exp_code', type=str, default=None,
                    help='experiment code for saving results')
parser.add_argument('--task', type=str, default='unspecified_survival_task')
parser.add_argument('--target_col', type=str, default='os_survival_days')
parser.add_argument('--n_label_bins', type=int, default=4,
                    help='number of bins for event time discretization')

# --- dataset / splits ---
parser.add_argument('--data_source', type=str, default=None,
                    help='manually specify the data source')
parser.add_argument('--split_dir', type=str, default=None,
                    help='manually specify the set of splits to use')
parser.add_argument('--split_names', type=str, default='train,val,test',
                    help='delimited list for specifying names within each split')
parser.add_argument('--overwrite', action='store_true', default=False,
                    help='overwrite existing results')

# --- logging ---
parser.add_argument('--results_dir', default='./results',
                    help='results directory (default: ./results)')
parser.add_argument('--tags', type=str, nargs='+', default=None,
                    help='tags for logging')
# The command line is parsed at module level: build_datasets() reads the
# module-level `args`, so this assignment must not move under the guard below.
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    # Both options default to None but are dereferenced unconditionally below;
    # fail with a usage message instead of an opaque TypeError/AttributeError.
    if args.split_dir is None:
        parser.error('--split_dir is required')
    if args.data_source is None:
        parser.error('--data_source is required')

    print('task: ', args.task)
    args.split_dir = j_('splits', args.split_dir)
    print('split_dir: ', args.split_dir)

    # The split name is the third '/'-separated component of the joined path,
    # optionally suffixed with '_k=<fold index>'.
    split_dir_parts = args.split_dir.split('/')
    if len(split_dir_parts) < 3:
        parser.error(f"--split_dir must contain at least two '/'-separated components, "
                     f"got path: {args.split_dir}")
    split_num = split_dir_parts[2].split('_k=')
    args.split_name_clean = split_num[0]
    args.split_k = int(split_num[1]) if len(split_num) > 1 else 0

    if args.load_proto:
        # Explicit raise instead of assert: asserts are stripped under `python -O`,
        # and os.path.isfile(None) would raise a TypeError.
        if not args.proto_path or not os.path.isfile(args.proto_path):
            raise FileNotFoundError(f"Path {args.proto_path} doesn't exist!")
        args.proto_fname = '/'.join(args.proto_path.split('/')[-2:])
        # Not referenced again in this script section; kept as in the original.
        proto_fname_clean = '--'.join(args.proto_fname[:-4].split('/'))

    ### Allows you to pass in multiple data sources (separated by comma). If single data source, no change.
    args.data_source = args.data_source.split(',')
    check_params_same = []
    for src in args.data_source:
        ### check data source exists + extract feature name ###
        print('data source: ', src)
        if not os.path.isdir(src):
            raise NotADirectoryError(f"data source must be a directory: {src} invalid")

        ### parse patching info ###
        feat_name = os.path.basename(src)
        mag, patch_size = extract_patching_info(os.path.dirname(src))
        if (mag < 0 or patch_size < 0):
            raise ValueError(f"invalid patching info parsed for {src}")
        check_params_same.append([feat_name, mag, patch_size])

        #### parse model name ####
        # `parsed` and `feat_name` of the last data source are the ones used below.
        parsed = parse_model_name(feat_name)
        parsed.update({'patch_mag': mag, 'patch_size': patch_size})

    # All sources must agree on (mag, patch_size). No bare `except:` here, so
    # unrelated errors surface with their traceback.
    check_params_same = pd.DataFrame(check_params_same, columns=['feats_name', 'mag', 'patch_size'])
    n_param_sets = check_params_same.drop(['feats_name'], axis=1).drop_duplicates().shape[0]
    if n_param_sets != 1:
        print("Data sources do not share the same feature extraction parameters. Exiting...")
        sys.exit(1)  # non-zero status: this is a failure
    print("All data sources have the same feature extraction parameters.")

    ### Update parsed model parameters in args.Namespace ###
    for key, val in parsed.items():
        setattr(args, key, val)

    ### setup results dir ###
    if args.exp_code is not None:
        # Previously left unbound in this branch, causing a NameError below.
        exp_code = args.exp_code
    elif args.model_config == 'PANTHER_default':
        exp_code = f"{args.split_name_clean}::{args.model_config}+{args.emb_model_type}::{args.loss_fn}::{feat_name}"
    else:
        exp_code = f"{args.split_name_clean}::{args.model_config}::{feat_name}"

    args.results_dir = j_(args.results_dir,
                          args.task,
                          f'k={args.split_k}',
                          str(exp_code),
                          str(exp_code) + f"::{get_current_time()}")

    os.makedirs(args.results_dir, exist_ok=True)

    print("\n################### Settings ###################")
    for key, val in vars(args).items():
        print(f"{key}: {val}")

    # Persist the full resolved configuration next to the results.
    with open(j_(args.results_dir, 'config.json'), 'w') as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))

    #### train ####
    results = main(args)

    print("FINISHED!\n\n\n")