src/training/main_classification.py

"""
Main entry point for classification downstream tasks
"""
from __future__ import print_function
import argparse
import pdb
import os
from os.path import join as j_
import sys
# internal imports
from utils.file_utils import save_pkl
from utils.utils import (seed_torch, array2list, merge_dict, read_splits,
parse_model_name, get_current_time,
extract_patching_info)
from .trainer import train
from wsi_datasets import WSIClassificationDataset
from data_factory import tasks, label_dicts
import torch
from torch.utils.data import DataLoader, sampler
import pandas as pd
import numpy as np
import json
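# Models that summarize each WSI into a fixed-length, prototype-based slide
# representation; their feature bags can be batched directly (see build_datasets).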
PROTO_MODELS = ['PANTHER', 'OT', 'H2T', 'ProtoCount']


def build_sampler(dataset, sampler_type=None):
    data_sampler = None
    if sampler_type is None:
        return data_sampler

    assert sampler_type in ['weighted', 'random', 'sequential']
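    # 'weighted' draws each slide with probability proportional to the inverse
    # frequency of its class label, rebalancing skewed class distributions.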
    if sampler_type == 'weighted':
        labels = dataset.get_labels(np.arange(len(dataset)), apply_transform=True)
        uniques, counts = np.unique(labels, return_counts=True)
        weights = {uniques[i]: 1. / counts[i] for i in range(len(uniques))}
        samples_weight = np.array([weights[t] for t in labels])
        samples_weight = torch.from_numpy(samples_weight)
        data_sampler = sampler.WeightedRandomSampler(samples_weight.double(), len(samples_weight))
    elif sampler_type == 'random':
        data_sampler = sampler.RandomSampler(dataset)
    elif sampler_type == 'sequential':
        data_sampler = sampler.SequentialSampler(dataset)

    return data_sampler
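

# Example (hypothetical objects): build_sampler(train_set, 'weighted') returns a
# WeightedRandomSampler that rebalances a long-tailed label distribution.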


def build_datasets(csv_splits, model_type, batch_size=1, num_workers=2,
                   train_kwargs={}, val_kwargs={},
                   sampler_types={'train': 'random',
                                  'val': 'sequential',
                                  'test': 'sequential'}):
    """
    Construct dataloaders from the data splits
    """
    dataset_splits = {}
    for k in csv_splits.keys():  # ['train', 'val', 'test']
        print("\nSPLIT: ", k)
        df = csv_splits[k]
        dataset_kwargs = train_kwargs.copy() if (k == 'train') else val_kwargs.copy()
        if k == 'test_nlst':
            dataset_kwargs['sample_col'] = 'case_id'

        dataset = WSIClassificationDataset(df, **dataset_kwargs)
        data_sampler = build_sampler(dataset, sampler_type=sampler_types.get(k, 'sequential'))

        # With prototype methods, every WSI has the same feature-bag dimension and is batchable.
        # Otherwise, a batch size of 1 is needed to accommodate each WSI's different bag size,
        # unless the same number of patch features is sampled per WSI (bag_size > 0) to allow larger batches.
        if model_type not in PROTO_MODELS:
            batch_size = batch_size if dataset_kwargs.get('bag_size', -1) > 0 else 1
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=data_sampler, num_workers=num_workers)
        dataset_splits[k] = dataloader
        print(f'split: {k}, n: {len(dataset)}')

    return dataset_splits


def main(args):
    if args.train_bag_size == -1:
        args.train_bag_size = args.bag_size
    if args.val_bag_size == -1:
        args.val_bag_size = args.bag_size

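    # Prototype-based models always read training slides sequentially; other
    # models honor --train_sampler. Validation/test splits are always sequential.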
    sampler_types = {'train': args.train_sampler if args.model_type not in PROTO_MODELS else 'sequential',
                     'val': 'sequential',
                     'test': 'sequential'}

    train_kwargs = dict(data_source=args.data_source,
                        label_map=args.label_map,
                        target_col=args.target_col,
                        bag_size=args.train_bag_size,
                        shuffle=True)

    # use the whole bag at test time
    val_kwargs = dict(data_source=args.data_source,
                      label_map=args.label_map,
                      target_col=args.target_col,
                      bag_size=args.val_bag_size)

    all_results, all_dumps = {}, {}

    # Cross-validation
    seed_torch(args.seed)
    csv_splits = read_splits(args)
    print('successfully read splits for: ', list(csv_splits.keys()))

    dataset_splits = build_datasets(csv_splits,
                                    model_type=args.model_type,
                                    batch_size=args.batch_size,
                                    num_workers=args.num_workers,
                                    sampler_types=sampler_types,
                                    train_kwargs=train_kwargs,
                                    val_kwargs=val_kwargs)

    fold_results, fold_dumps = train(dataset_splits, args, mode='classification')
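
    # Accumulate this fold's metrics into the running per-split results and
    # dump the raw per-fold outputs as they are produced.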
    for split, split_results in fold_results.items():
        if split not in all_results.keys():
            all_results[split] = merge_dict({}, split_results)
        else:
            all_results[split] = merge_dict(all_results[split], split_results)
        save_pkl(j_(args.results_dir, f'{split}_results.pkl'), fold_dumps[split])  # saves per-split, per-fold results to pkl

    # Flatten metrics into one '<metric>_<split>' table and save a summary
    final_dict = {}
    for split, split_results in all_results.items():
        final_dict.update({f'{metric}_{split}': array2list(val) for metric, val in split_results.items()})
    final_df = pd.DataFrame(final_dict)
    save_name = 'summary.csv'
    final_df.to_csv(j_(args.results_dir, save_name), index=False)

    with open(j_(args.results_dir, save_name + '.json'), 'w') as f:
        f.write(json.dumps(final_dict, sort_keys=True, indent=4))

    dump_path = j_(args.results_dir, 'all_dumps.h5')
    fold_dumps.update({'labels': np.array(list(args.label_map.keys()), dtype=np.object_)})
    save_pkl(dump_path, fold_dumps)  # note: pickled via save_pkl despite the '.h5' extension

    return final_dict
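

# Example invocation (hypothetical paths); run as a module from src/ so the
# relative import of .trainer resolves, e.g.:
#   python -m training.main_classification --task <task_name> \
#       --data_source /path/to/feats_A,/path/to/feats_B \
#       --split_dir <splits_subdir> --model_type PANTHER --load_proto \
#       --proto_path /path/to/prototypes.pkl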

# Generic training settings
parser = argparse.ArgumentParser(description='Configurations for WSI Training')

### optimizer settings ###
parser.add_argument('--max_epochs', type=int, default=20,
                    help='maximum number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--wd', type=float, default=1e-5,
                    help='weight decay')
parser.add_argument('--accum_steps', type=int, default=1,
                    help='grad accumulation steps')
parser.add_argument('--opt', type=str,
                    choices=['adamW', 'sgd'], default='adamW')
parser.add_argument('--lr_scheduler', type=str,
                    choices=['cosine', 'linear', 'constant'], default='constant')
parser.add_argument('--warmup_steps', type=int, default=-1,
                    help='warmup iterations')
parser.add_argument('--warmup_epochs', type=int, default=-1,
                    help='warmup epochs')
parser.add_argument('--batch_size', type=int, default=1)

### misc ###
parser.add_argument('--print_every', default=100, type=int,
                    help='how often to print')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--num_workers', type=int, default=2)

### Earlystopper args ###
parser.add_argument('--early_stopping', action='store_true', default=False,
                    help='enable early stopping')
parser.add_argument('--es_min_epochs', type=int, default=15,
                    help='early stopping min epochs')
parser.add_argument('--es_patience', type=int, default=10,
                    help='early stopping patience')
parser.add_argument('--es_metric', type=str, default='loss',
                    help='early stopping metric')

### model / loss fn args ###
parser.add_argument('--model_type', type=str, default='ABMIL',
                    choices=['H2T', 'ABMIL', 'TransMIL', 'SumMIL', 'OT', 'PANTHER', 'ProtoCount', 'DeepAttnMIL', 'ILRA'],
                    help='type of model')
parser.add_argument('--emb_model_type', type=str, default='LinEmb_LR')
parser.add_argument('--ot_eps', default=0.1, type=float,
                    help='strength of the entropic regularization for OT')
parser.add_argument('--model_config', type=str, default='ABMIL_default',
                    help='name of model config file')
parser.add_argument('--in_dim', default=768, type=int,
                    help='dim of input features')
parser.add_argument('--in_dropout', default=0.0, type=float,
                    help='probability of dropping out input features')
parser.add_argument('--bag_size', type=int, default=-1)
parser.add_argument('--train_bag_size', type=int, default=-1)
parser.add_argument('--val_bag_size', type=int, default=-1)
parser.add_argument('--train_sampler', type=str, default='random',
                    choices=['random', 'weighted', 'sequential'])
parser.add_argument('--n_fc_layers', type=int)
parser.add_argument('--em_iter', type=int)
parser.add_argument('--tau', type=float)
parser.add_argument('--out_type', type=str, default='param_cat')

### Prototype related ###
parser.add_argument('--load_proto', action='store_true', default=False)
parser.add_argument('--proto_path', type=str, default='.')
parser.add_argument('--fix_proto', action='store_true', default=False)
parser.add_argument('--n_proto', type=int)

### experiment task / label args ###
parser.add_argument('--exp_code', type=str,
                    help='experiment code for saving results')
parser.add_argument('--task', type=str, choices=tasks)
parser.add_argument('--target_col', type=str, default='label')

### dataset / split args ###
parser.add_argument('--data_source', type=str, default=None,
                    help='manually specify the data source')
parser.add_argument('--split_dir', type=str, default=None,
                    help='manually specify the set of splits to use')
parser.add_argument('--split_names', type=str, default='train,val,test',
                    help='delimited list for specifying names within each split')
parser.add_argument('--overwrite', action='store_true', default=False,
                    help='overwrite existing results')

### logging args ###
parser.add_argument('--results_dir', default='./results',
                    help='results directory (default: ./results)')
parser.add_argument('--tags', nargs='+', type=str, default=None,
                    help='tags for logging')

args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    args.label_map = label_dicts[args.task]
    print('label map: ', args.label_map)
    args.n_classes = len(set(list(args.label_map.values())))
    print('task: ', args.task)

    args.split_dir = j_('splits', args.split_dir)
    print('split_dir: ', args.split_dir)
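
    # Parse the base split name and k from a directory name of the form '<name>_k=<k>'.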
    split_num = args.split_dir.split('/')[2].split('_k=')
    args.split_name_clean = split_num[0]
    if len(split_num) > 1:
        args.split_k = int(split_num[1])
    else:
        args.split_k = 0

    print(args.proto_path)
    if os.path.isfile(args.proto_path):
        args.proto_fname = '/'.join(args.proto_path.split('/')[-2:])

    ### Allows passing in multiple data sources (comma-separated). A single data source is unchanged. ###
    args.data_source = args.data_source.split(',')
    check_params_same = []
    for src in args.data_source:
        ### assert data source exists + extract feature name ###
        print('data source: ', src)
        assert os.path.isdir(src), f"data source must be a directory: {src} invalid"

        ### parse patching info ###
        feat_name = os.path.basename(src)
        mag, patch_size = extract_patching_info(os.path.dirname(src))
        if (mag < 0 or patch_size < 0):
            raise ValueError(f"invalid patching info parsed for {src}")
        check_params_same.append([feat_name, mag, patch_size])
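
    # All data sources must share the same magnification and patch size; only
    # the feature (extractor) names are allowed to differ.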
    try:
        check_params_same = pd.DataFrame(check_params_same, columns=['feats_name', 'mag', 'patch_size'])
        print(check_params_same.to_string())
        assert check_params_same.drop(['feats_name'], axis=1).drop_duplicates().shape[0] == 1
        print("All data sources have the same feature extraction parameters.")
    except AssertionError:
        print("Data sources do not share the same feature extraction parameters. Exiting...")
        sys.exit()

    ### Update parsed model parameters in args.Namespace ###
    #### parse patching info ####
    mag, patch_size = extract_patching_info(os.path.dirname(args.data_source[0]))

    #### parse model name ####
    parsed = parse_model_name(feat_name)  # feat_name from the last data source in the loop above
    parsed.update({'patch_mag': mag, 'patch_size': patch_size,
                   'feat_names': sorted(list(set(check_params_same['feats_name'].tolist())))})
    for key, val in parsed.items():
        setattr(args, key, val)

    ### setup results dir ###
    if args.exp_code is None:
        if args.model_config == 'PANTHER_default':
            exp_code = f"{args.split_name_clean}::{args.model_config}+{args.emb_model_type}::{args.loss_fn}::{feat_name}"
        else:
            exp_code = f"{args.split_name_clean}::{args.model_config}::{feat_name}"
    else:
        exp_code = args.exp_code  # use the user-supplied experiment code
    args.results_dir = j_(args.results_dir,
                          args.task,
                          f'k={args.split_k}',
                          str(exp_code),
                          str(exp_code) + f"::{get_current_time()}")
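    # Final layout: <results_dir>/<task>/k=<split_k>/<exp_code>/<exp_code>::<timestamp>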
    os.makedirs(args.results_dir, exist_ok=True)

    print("\n################### Settings ###################")
    for key, val in vars(args).items():
        print("{}: {}".format(key, val))

    with open(j_(args.results_dir, 'config.json'), 'w') as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))

    #### train ####
    results = main(args)
    print("FINISHED!\n\n\n")