Diff of /exec.py [000000] .. [bb7f56]

Switch to unified view

a b/exec.py
1
#!/usr/bin/env python
2
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
# ==============================================================================
16
17
"""execution script."""
18
19
import code
20
import argparse
21
import os, warnings
22
import time
23
import pandas as pd
24
import pickle
25
import sys
26
import numpy as np
27
28
import torch
29
30
import utils.exp_utils as utils
31
from evaluator import Evaluator
32
from predictor import Predictor
33
from plotting import plot_batch_prediction
34
from datetime import datetime
35
36
for msg in ["Attempting to set identical bottom==top results",
37
            "This figure includes Axes that are not compatible with tight_layout",
38
            "Data has no positive values, and therefore cannot be log-scaled.",
39
            ".*invalid value encountered in double_scalars.*",
40
            ".*Mean of empty slice.*"]:
41
    warnings.filterwarnings("ignore", msg)
42
43
44
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.

    Relies on module-level globals set up in __main__: ``cf`` (experiment
    config), ``model`` (imported model module) and ``data_loader`` (imported
    experiment data-loader module). Writes checkpoints via ModelSelector and
    example-prediction plots into ``cf.plot_dir``.
    """
    time_start_train = time.time()
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    # optimizer choice: Adam only if the config explicitly asks for it,
    # otherwise AdamW (decoupled weight decay) is the default.
    if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam":
        logger.info("Using Adam optimizer.")
        optimizer = torch.optim.Adam(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                                                   exclude_from_wd=cf.exclude_from_wd),
                                      lr=cf.learning_rate[0])
    else:
        logger.info("Using AdamW optimizer.")
        optimizer = torch.optim.AdamW(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
                                                                   exclude_from_wd=cf.exclude_from_wd),
                                      lr=cf.learning_rate[0])


    # NOTE: `scheduler` only exists when dynamic scheduling is on; it is used
    # below under the same cf.dynamic_lr_scheduling guard, so this is safe.
    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
                                                               patience=cf.scheduling_patience)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    # optionally resume from the fold's last checkpoint; this replaces net,
    # optimizer state, metrics history and the starting epoch in one go.
    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint")
        starting_epoch, net, optimizer, monitor_metrics = \
            utils.load_checkpoint(checkpoint_path, net, optimizer)
        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))

    ####### Use this to create hdf5
    logger.info('loading dataset and initializing batch generators...')
    print ("Starting data_loader.get_train_generators in exec...",datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
    batch_gen = data_loader.get_train_generators(cf, logger)
    print ("Finished data_loader.get_train_generators in exec...",datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))

    ####### Writing out train data to file (dormant debugging aid)
    #train_data = dict()
    #print ('Write training data to json')
    #for bix in range(cf.num_train_batches):
    #     batch = next(batch_gen['train'])
    #     train_data.update(batch)
    #with open('train_data.json', 'w') as outfile:
    #    json.dump(train_data, outfile)
    #####################################

    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        start_time = time.time()
        net.train()
        train_results_list = []
        for bix in range(cf.num_train_batches):
            ######### Insert call to grab right training data fold from hdf5
            # the timestamped prints below are debugging aids left in to trace
            # where a stalled run hangs (generator vs forward vs backward).
            print ("Get next batch_gen['train] ...",datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            ##Stalled
            batch = next(batch_gen['train']) ######## Instead of this line, grab a batch from training data fold
            tic_fw = time.time()
            print ("Start forward pass...",time.time())
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            print ("Start backward pass..",datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            results_dict['torch_loss'].backward()
            print ("Start optimizing...",datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            optimizer.step()
            print('\rtr. batch {0}/{1} (ep. {2}) fw {3:.2f}s / bw {4:.2f} s / total {5:.2f} s || '.format(
                bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw, time.time() - tic_bw,
                time.time() - tic_fw) + results_dict['logger_string'], flush=True, end="")
            # NOTE: sys.getsizeof is shallow -- this reports the dict header
            # size only, not the tensors it references.
            print ("Results Dict Size: ",sys.getsizeof(results_dict))
            # drop the bulky 'seg_preds' before accumulating results for the
            # epoch-level evaluation, to keep memory bounded.
            train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"]))
            print("Loop through train batch DONE",datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"),(time.time()-time_start_train)/60, "minutes since training started")

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])

        # plot the *last* training batch of the epoch in a child process so
        # plotting does not block the training loop.
        logger.info('generating training example plot.')
        utils.split_off_process(plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join(
           cf.plot_dir, 'pred_example_{}_train.png'.format(cf.fold)))

        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                for _ in range(batch_gen['n_val']):
                    ########## Insert call to grab right validation data fold from hdf5
                    batch = next(batch_gen[cf.val_mode])
                    # 'val_patient' runs full-patient inference via the
                    # Predictor, 'val_sampling' reuses the training forward.
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                    #val_results_list.append([results_dict['boxes'], batch['pid']])
                    val_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"]))

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
                # checkpointing / best-model bookkeeping happens here.
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            # update monitoring and prediction plots
            monitor_metrics.update({"lr":
                                        {str(g): group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
            logger.metrics2tboard(monitor_metrics, global_step=epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} ({} train / {} val)'.format(
                epoch, utils.get_formatted_duration(epoch_time, "ms"), utils.get_formatted_duration(train_time, "ms"),
                utils.get_formatted_duration(epoch_time-train_time, "ms")))
            ########### Insert call to grab right validation data fold from hdf5
            # an extra val_sampling batch is always drawn just for the example
            # plot, independent of cf.do_validation.
            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('generating validation-sampling example plot.')
            utils.split_off_process(plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join(
                cf.plot_dir, 'pred_example_{}_val.png'.format(cf.fold)))

        # -------------- scheduling -----------------
        # either plateau-based decay on the chosen validation criterion, or a
        # fixed per-epoch schedule read from cf.learning_rate.
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch-1]
177
178
def test(logger):
    """
    perform testing for a given fold (or hold out set). save stats in evaluator.

    Uses the module-level globals ``cf``, ``model`` and ``data_loader`` set up
    in __main__.
    """
    # Index range implied by the preprocessed test-data directory layout
    # (presumably 3 files per case plus 2 non-case entries -- TODO confirm).
    # Currently only reported, not used for batching.
    case_indices = np.arange((len(os.listdir(cf.pp_test_data_path))/3)-2,dtype=np.int16)
    print ("Test indices size: "+str(len(case_indices))+" and is held out test set "+str(cf.hold_out_test_set))

    logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))

    # build the network plus the helpers that run and score the predictions
    net = model.net(cf, logger).cuda()
    predictor = Predictor(cf, net, logger, mode='test')
    evaluator = Evaluator(cf, logger, mode='test')

    # the test generator comes from the experiment's own data_loader module
    test_gen = data_loader.get_test_generator(cf, logger)
    results = predictor.predict_test_set(test_gen, return_results=True)
    evaluator.evaluate_predictions(results)
    evaluator.score_test_df()
195
196
if __name__ == '__main__':
    stime = time.time()

    # ----------------------------- command line -----------------------------
    cli = argparse.ArgumentParser()
    cli.add_argument('-m', '--mode', type=str,  default='train_test',
                     help='one out of: train / test / train_test / analysis / create_exp')
    cli.add_argument('-f','--folds', nargs='+', type=int, default=None,
                     help='None runs over all folds in CV. otherwise specify list of folds.')
    cli.add_argument('--exp_dir', type=str, default='/path/to/experiment/directory',
                     help='path to experiment dir. will be created if non existent.')
    cli.add_argument('--server_env', default=False, action='store_true',
                     help='change IO settings to deploy models on a cluster.')
    cli.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config.")
    cli.add_argument('--use_stored_settings', default=False, action='store_true',
                     help='load configs from existing exp_dir instead of source dir. always done for testing, '
                          'but can be set to true to do the same for training. useful in job scheduler environment, '
                          'where source code might change before the job actually runs.')
    cli.add_argument('--resume', action="store_true", default=False,
                     help='if given, resume from checkpoint(s) of the specified folds.')
    cli.add_argument('--exp_source', type=str, default='experiments/toy_exp',
                     help='specifies, from which source experiment to load configs and data_loader.')
    cli.add_argument('--no_benchmark', action='store_true', help="Do not use cudnn.benchmark.")
    cli.add_argument('--cuda_device', type=int, default=0, help="Index of CUDA device to use.")
    cli.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")

    args = cli.parse_args()
    folds = args.folds

    # cudnn autotuner is on unless explicitly disabled
    torch.backends.cudnn.benchmark = not args.no_benchmark

    ########### Dormant scaffold for a future 'create_hdf5' mode:
    #if args.mode = 'create_hdf5':
    #    if folds is None:
    #       folds = range(cf.n_cv_splits)
    #    for fold in folds:
    #       create_hdf_foldwise_with_batch_generator_for_train/val/test
233
234
235
    if args.mode == 'train' or args.mode == 'train_test':

        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, args.use_stored_settings)
        if args.dev:
            # development mode: tiny schedule so a full run finishes quickly
            folds = [0,1]
            cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 2
            cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
            cf.test_n_epochs =  cf.save_n_models
            cf.max_test_patients = 2

        cf.data_dest = args.data_dest
        logger = utils.get_logger(cf.exp_dir, cf.server_env)
        logger.info("cudnn benchmark: {}, deterministic: {}.".format(torch.backends.cudnn.benchmark,
                                                                     torch.backends.cudnn.deterministic))
        logger.info("sending tensors to CUDA device: {}.".format(torch.cuda.get_device_name(args.cuda_device)))
        # data loader and model are experiment-specific modules imported at runtime
        data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
        model = utils.import_module('model', cf.model_path)
        logger.info("loaded model from {}".format(cf.model_path))
        if folds is None:
            folds = range(cf.n_cv_splits)

        with torch.cuda.device(args.cuda_device):
            for fold in folds:
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
                cf.fold = fold
                cf.resume = args.resume
                if not os.path.exists(cf.fold_dir):
                    # was os.mkdir: makedirs also creates missing parent dirs
                    os.makedirs(cf.fold_dir)
                logger.set_logfile(fold=fold)
                train(logger)
                cf.resume = False
                if args.mode == 'train_test':
                    test(logger)

            # cross-validation case: aggregate the per-fold test artifacts into
            # experiment-level files once all folds are done.
            if not cf.hold_out_test_set:
                # concatenate per-fold test dataframes into one csv
                test_frames = [pd.read_pickle(os.path.join(cf.test_dir, f))
                               for f in os.listdir(cf.test_dir) if '_test_df.pickle' in f]
                all_preds = pd.concat(test_frames)
                all_preds.to_csv(os.path.join(cf.test_dir, "all_folds_test.csv"))

                # concatenate raw and wbc prediction boxes across folds
                # (previously two copy-pasted blocks; filenames are unchanged)
                for box_type in ("raw", "wbc"):
                    det_frames = [pd.read_pickle(os.path.join(cf.exp_dir, f, '{}_pred_boxes_list.pickle'.format(box_type)))
                                  for f in os.listdir(cf.exp_dir) if 'fold_' in f]
                    all_dets = []
                    for dets in det_frames:
                        all_dets.extend(dets)
                    with open(os.path.join(cf.exp_dir, 'all_{}_dets.pickle'.format(box_type)), 'wb') as handle:
                        pickle.dump(all_dets, handle)
291
292
    elif args.mode == 'test':
293
294
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
295
        if args.dev:
296
            folds = [0,1]
297
            cf.test_n_epochs = 2; cf.max_test_patients = 2
298
299
        cf.data_dest = args.data_dest
300
        logger = utils.get_logger(cf.exp_dir, cf.server_env)
301
        data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
302
        model = utils.import_module('model', cf.model_path)
303
        logger.info("loaded model from {}".format(cf.model_path))
304
        if folds is None:
305
            folds = range(cf.n_cv_splits)
306
307
        with torch.cuda.device(args.cuda_device):
308
            for fold in folds:
309
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
310
                cf.fold = fold
311
                logger.set_logfile(fold=fold)
312
                test(logger)
313
314
            if cf.hold_out_test_set == True:
315
                 test_mode = "_hold_out"
316
            else:
317
                test_mode = None
318
            test_frames = [pd.read_pickle(os.path.join(cf.test_dir,f)) for f  in os.listdir(cf.test_dir) if '_test_df.pickle' in f] 
319
            all_preds = pd.concat(test_frames)
320
            all_preds.to_csv(os.path.join(cf.test_dir,"all_folds{}_test.csv".format(test_mode)))
321
322
        #Concatenate detection raw boxes across folds
323
            det_frames = [pd.read_pickle(os.path.join(cf.exp_dir,f,'raw_pred_boxes{}_list.pickle'.format(test_mode))) for f in os.listdir(cf.exp_dir) if 'fold_' in f]
324
            all_dets=list()
325
            for i in det_frames:
326
              all_dets.extend(i)
327
            with open(os.path.join(cf.exp_dir, 'all_raw{}_dets.pickle'.format(test_mode)), 'wb') as handle:
328
              pickle.dump(all_dets, handle)
329
330
        #Concatenate detection wbc boxes across folds
331
            det_frames = [pd.read_pickle(os.path.join(cf.exp_dir,f,'wbc_pred_boxes{}_list.pickle'.format(test_mode))) for f in os.listdir(cf.exp_dir) if 'fold_' in f]
332
            all_dets=list()
333
            for i in det_frames:
334
              all_dets.extend(i)
335
            with open(os.path.join(cf.exp_dir, 'all_wbc{}_dets.pickle'.format(test_mode)), 'wb') as handle:
336
              pickle.dump(all_dets, handle)
337
338
    # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
339
    elif args.mode == 'analysis':
340
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
341
        logger = utils.get_logger(cf.exp_dir, cf.server_env)
342
343
        if args.dev:
344
            cf.test_n_epochs = 2
345
346
        if cf.hold_out_test_set and cf.ensemble_folds:
347
            # create and save (unevaluated) predictions across all folds
348
            predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
349
            results_list = predictor.load_saved_predictions(apply_wbc=True)
350
            utils.create_csv_output([(res_dict["boxes"], pid) for res_dict, pid in results_list], cf, logger)
351
            logger.info('starting evaluation...')
352
            cf.fold = 'overall_hold_out'
353
            evaluator = Evaluator(cf, logger, mode='test')
354
            evaluator.evaluate_predictions(results_list)
355
            evaluator.score_test_df()
356
357
        else:
358
            fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
359
                         os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
360
            if folds is None:
361
                folds = range(cf.n_cv_splits)
362
            for fold in folds:
363
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
364
                cf.fold = fold
365
                logger.set_logfile(fold=fold)
366
                if cf.fold_dir in fold_dirs:
367
                    predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
368
                    results_list = predictor.load_saved_predictions(apply_wbc=True)
369
                    logger.info('starting evaluation...')
370
                    evaluator = Evaluator(cf, logger, mode='test')
371
                    evaluator.evaluate_predictions(results_list)
372
                    evaluator.score_test_df()
373
                else:
374
                    logger.info("Skipping fold {} since no model parameters found.".format(fold))
375
376
    # create experiment folder and copy scripts without starting job.
377
    # useful for cloud deployment where configs might change before job actually runs.
378
    elif args.mode == 'create_exp':
379
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, use_stored_settings=False)
380
        logger = utils.get_logger(cf.exp_dir)
381
        logger.info('created experiment directory at {}'.format(cf.exp_dir))
382
383
    else:
384
        raise RuntimeError('mode specified in args is not implemented...')
385
386
387
    t = utils.get_formatted_duration(time.time() - stime)
388
    logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t))
389
    del logger