Diff of /exec_dp.py [000000] .. [bb7f56]

#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""execution script."""
18
19
import code
20
import argparse
21
import os, warnings
22
import time
23
import pandas as pd
24
import pickle
25
import sys
26
import cProfile, pstats
27
28
import torch
29
import torch.nn as nn
30
31
import utils.exp_utils as utils
32
from evaluator import Evaluator
33
from predictor import Predictor
34
from plotting import plot_batch_prediction
35
from datetime import datetime
36
for msg in ["Attempting to set identical bottom==top results",
            "This figure includes Axes that are not compatible with tight_layout",
            "Data has no positive values, and therefore cannot be log-scaled.",
            ".*invalid value encountered in double_scalars.*",
            ".*Mean of empty slice.*"]:
    warnings.filterwarnings("ignore", msg)


def train(logger):
    """
    Perform the training routine for a given fold. Saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    time_start_train = time.time()
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))

    print("Number of CUDA devices available:", torch.cuda.device_count())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # run on GPU if available, else CPU

    net = model.net(cf, logger).cuda()

    #net = nn.DataParallel(net).to(device)

    print("Is the net wrapped in nn.DataParallel?", isinstance(net, nn.DataParallel))

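    # Select the optimizer: Adam when the config explicitly requests it, otherwise AdamW.
    # Weight decay and the parameters excluded from weight decay are taken from the config
    # via utils.parse_params_for_optim.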
    if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam":
65
        logger.info("Using Adam optimizer.")
66
        optimizer = torch.optim.Adam(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
67
                                                                   exclude_from_wd=cf.exclude_from_wd),
68
                                      lr=cf.learning_rate[0])
69
    else:
70
        logger.info("Using AdamW optimizer.")
71
        optimizer = torch.optim.AdamW(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay,
72
                                                                   exclude_from_wd=cf.exclude_from_wd),
73
                                      lr=cf.learning_rate[0])
74
75
76
    if cf.dynamic_lr_scheduling:
77
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
78
                                                               patience=cf.scheduling_patience)
79
80
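    # The evaluators aggregate per-batch results into monitoring metrics; the ModelSelector
    # decides which checkpoints to keep based on the validation metrics.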
    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint")
        starting_epoch, net, optimizer, monitor_metrics = \
            utils.load_checkpoint(checkpoint_path, net, optimizer)
        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))

    ####### Use this to create hdf5
    logger.info('loading dataset and initializing batch generators...')
    print("Starting data_loader.get_train_generators in exec ...", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
    batch_gen = data_loader.get_train_generators(cf, logger)
    print("Finished data_loader.get_train_generators in exec ...", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))

    ####### Writing out train data to file (requires `import json`; kept for reference)
    #train_data = dict()
    #print('Write training data to json')
    #for bix in range(cf.num_train_batches):
    #     batch = next(batch_gen['train'])
    #     train_data.update(batch)
    #with open('train_data.json', 'w') as outfile:
    #    json.dump(train_data, outfile)
    #####################################

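    # Main training loop: iterate over epochs, pulling cf.num_train_batches batches per epoch,
    # then run validation and (optionally) learning-rate scheduling.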
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        start_time = time.time()
        net.train()
        train_results_list = []
        for bix in range(cf.num_train_batches):

            # profiler = cProfile.Profile()
            # profiler.enable()


            ######### Insert call to grab right training data fold from hdf5
            print("Get next batch_gen['train'] ...", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            ##Stalled
            batch = next(batch_gen['train']) ######## Instead of this line, grab a batch from training data fold
            tic_fw = time.time()
            print("Start forward pass...", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            print("Start backward pass...", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            results_dict['torch_loss'].backward()
            print("Start optimizing...", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"))
            optimizer.step()
            print('\rtr. batch {0}/{1} (ep. {2}) fw {3:.2f}s / bw {4:.2f} s / total {5:.2f} s || '.format(
                bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw, time.time() - tic_bw,
                time.time() - tic_fw) + results_dict['logger_string'], flush=True, end="")
            print("Results Dict Size: ", sys.getsizeof(results_dict))
            train_results_list.append(({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))
            print("Loop through train batch DONE", datetime.now().strftime("%m/%d/%Y %H:%M:%S:%f"),
                  (time.time() - time_start_train) / 60, "minutes since training started")

            # profiler.disable()
            # stats = pstats.Stats(profiler).sort_stats('cumtime')
            # stats.print_stats()

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])

        # logger.info('generating training example plot.')
        # utils.split_off_process(plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join(
        #    cf.plot_dir, 'pred_example_{}_train.png'.format(cf.fold)))

        train_time = time.time() - start_time

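        # Validation: either whole-patient prediction ('val_patient') or sampled batches run
        # through net.train_forward ('val_sampling'), chosen by cf.val_mode.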
        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                for _ in range(batch_gen['n_val']):
                    ########## Insert call to grab right validation data fold from hdf5
                    batch = next(batch_gen[cf.val_mode])
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                    #val_results_list.append([results_dict['boxes'], batch['pid']])
                    val_results_list.append(({k: v for k, v in results_dict.items() if k != "seg_preds"}, batch["pid"]))
                    #monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values'])

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            # update monitoring and prediction plots
            monitor_metrics.update({"lr":
                                        {str(g): group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
            logger.metrics2tboard(monitor_metrics, global_step=epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} ({} train / {} val)'.format(
                epoch, utils.get_formatted_duration(epoch_time, "ms"), utils.get_formatted_duration(train_time, "ms"),
                utils.get_formatted_duration(epoch_time - train_time, "ms")))
            ########### Insert call to grab right validation data fold from hdf5
            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('generating validation-sampling example plot.')
            utils.split_off_process(plot_batch_prediction, batch, results_dict, cf, outfile=os.path.join(
                cf.plot_dir, 'pred_example_{}_val.png'.format(cf.fold)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch-1]

def test(logger):
    """
    Perform testing for a given fold (or hold-out set). Saves stats in the evaluator.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # run on GPU if available, else CPU

    logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
    net = model.net(cf, logger).cuda()

    #net = nn.DataParallel(net).to(device)

    test_predictor = Predictor(cf, net, logger, mode='test')
    test_evaluator = Evaluator(cf, logger, mode='test')
    ################ Insert call to grab right test data (fold?) from hdf5
    batch_gen = data_loader.get_test_generator(cf, logger)
    ####code.interact(local=locals())
    test_results_list = test_predictor.predict_test_set(batch_gen, return_results=True)
    test_evaluator.evaluate_predictions(test_results_list)
    test_evaluator.score_test_df()

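# Command-line entry point: parse arguments, then dispatch to training, testing, analysis,
# or experiment-folder creation depending on --mode.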
if __name__ == '__main__':
    stime = time.time()

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--mode', type=str, default='train_test',
                        help='one out of: train / test / train_test / analysis / create_exp')
    parser.add_argument('-f', '--folds', nargs='+', type=int, default=None,
                        help='None runs over all folds in CV. Otherwise specify a list of folds.')
    parser.add_argument('--exp_dir', type=str, default='/path/to/experiment/directory',
                        help='path to experiment dir. will be created if non-existent.')
    parser.add_argument('--server_env', default=False, action='store_true',
                        help='change IO settings to deploy models on a cluster.')
    parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config.")
    parser.add_argument('--use_stored_settings', default=False, action='store_true',
                        help='load configs from existing exp_dir instead of source dir. always done for testing, '
                             'but can be set to true to do the same for training. useful in job scheduler environments, '
                             'where source code might change before the job actually runs.')
    parser.add_argument('--resume', action="store_true", default=False,
                        help='if given, resume from checkpoint(s) of the specified folds.')
    parser.add_argument('--exp_source', type=str, default='experiments/toy_exp',
                        help='specifies from which source experiment to load configs and data_loader.')
    parser.add_argument('--no_benchmark', action='store_true', help="Do not use cudnn.benchmark.")
    parser.add_argument('--cuda_device', type=int, default=0, help="Index of CUDA device to use.")
    parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")

    args = parser.parse_args()
    folds = args.folds

    torch.backends.cudnn.benchmark = not args.no_benchmark

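    # Example invocation (the paths below are placeholders, not part of the original script):
    #   python exec_dp.py --mode train_test --exp_source experiments/toy_exp \
    #       --exp_dir /path/to/experiment/directory --folds 0 1 2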
    ########### Creating hdf5 (sketch kept as comments; 'create_hdf5' is not an implemented mode yet)
    #if args.mode == 'create_hdf5':
    #    if folds is None:
    #       folds = range(cf.n_cv_splits)

    #    for fold in folds:
    #       create_hdf_foldwise_with_batch_generator_for_train/val/test


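    # A minimal sketch of what such per-fold HDF5 caching could look like, assuming h5py is
    # available and that batches are dicts of numpy arrays (e.g. 'data', 'seg', 'pid'). The
    # helper name and file layout below are hypothetical, not part of the existing code base:
    #
    # import h5py
    # import numpy as np
    #
    # def cache_fold_to_hdf5(batch_gen_part, num_batches, out_path):
    #     """Write `num_batches` batches from one generator part to an HDF5 file."""
    #     with h5py.File(out_path, 'w') as f:
    #         for bix in range(num_batches):
    #             batch = next(batch_gen_part)
    #             grp = f.create_group('batch_{}'.format(bix))
    #             for key, value in batch.items():
    #                 arr = np.asarray(value)
    #                 if arr.dtype.kind in ('U', 'S', 'O'):  # store strings (e.g. pids) as bytes
    #                     grp.create_dataset(key, data=arr.astype('S'))
    #                 else:
    #                     grp.create_dataset(key, data=arr)
    #
    # # e.g. cache_fold_to_hdf5(batch_gen['train'], cf.num_train_batches,
    # #                         os.path.join(cf.fold_dir, 'train_fold_{}.h5'.format(cf.fold)))
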
    if args.mode == 'train' or args.mode == 'train_test':

        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, args.use_stored_settings)
        if args.dev:
            folds = [0, 1]
            cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim == 2 else 1, 1, 0, 2
            cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
            cf.test_n_epochs = cf.save_n_models
            cf.max_test_patients = 2

        cf.data_dest = args.data_dest
        logger = utils.get_logger(cf.exp_dir, cf.server_env)
        logger.info("cudnn benchmark: {}, deterministic: {}.".format(torch.backends.cudnn.benchmark,
                                                                     torch.backends.cudnn.deterministic))
        logger.info("sending tensors to CUDA device: {}.".format(torch.cuda.get_device_name(args.cuda_device)))
        data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
        model = utils.import_module('model', cf.model_path)
        logger.info("loaded model from {}".format(cf.model_path))
        if folds is None:
            folds = range(cf.n_cv_splits)

        with torch.cuda.device(args.cuda_device):
            for fold in folds:
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
                cf.fold = fold
                cf.resume = args.resume
                if not os.path.exists(cf.fold_dir):
                    os.mkdir(cf.fold_dir)
                logger.set_logfile(fold=fold)
                train(logger)

                cf.resume = False
                if args.mode == 'train_test':
                    test(logger)

            # Concatenate test results by detection
            if not cf.hold_out_test_set:
                test_frames = [pd.read_pickle(os.path.join(cf.test_dir, f)) for f in os.listdir(cf.test_dir) if '_test_df.pickle' in f]
                all_preds = pd.concat(test_frames)
                all_preds.to_csv(os.path.join(cf.test_dir, "all_folds_test.csv"))

                # Concatenate detection raw boxes across folds
                det_frames = [pd.read_pickle(os.path.join(cf.exp_dir, f, 'raw_pred_boxes_list.pickle')) for f in os.listdir(cf.exp_dir) if 'fold_' in f]
                all_dets = []
                for i in det_frames:
                    all_dets.extend(i)
                with open(os.path.join(cf.exp_dir, 'all_raw_dets.pickle'), 'wb') as handle:
                    pickle.dump(all_dets, handle)

                # Concatenate detection wbc boxes across folds
                det_frames = [pd.read_pickle(os.path.join(cf.exp_dir, f, 'wbc_pred_boxes_list.pickle')) for f in os.listdir(cf.exp_dir) if 'fold_' in f]
                all_dets = []
                for i in det_frames:
                    all_dets.extend(i)
                with open(os.path.join(cf.exp_dir, 'all_wbc_dets.pickle'), 'wb') as handle:
                    pickle.dump(all_dets, handle)

    elif args.mode == 'test':

        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
        if args.dev:
            folds = [0, 1]
            cf.test_n_epochs = 2
            cf.max_test_patients = 2

        cf.data_dest = args.data_dest
        logger = utils.get_logger(cf.exp_dir, cf.server_env)
        data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
        model = utils.import_module('model', cf.model_path)
        logger.info("loaded model from {}".format(cf.model_path))
        if folds is None:
            folds = range(cf.n_cv_splits)

        with torch.cuda.device(args.cuda_device):
            for fold in folds:
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
                cf.fold = fold
                logger.set_logfile(fold=fold)
                test(logger)

            if not cf.hold_out_test_set:
                test_frames = [pd.read_pickle(os.path.join(cf.test_dir, f)) for f in os.listdir(cf.test_dir) if '_test_df.pickle' in f]
                all_preds = pd.concat(test_frames)
                all_preds.to_csv(os.path.join(cf.test_dir, "all_folds_test.csv"))

                # Concatenate detection raw boxes across folds
                det_frames = [pd.read_pickle(os.path.join(cf.exp_dir, f, 'raw_pred_boxes_list.pickle')) for f in os.listdir(cf.exp_dir) if 'fold_' in f]
                all_dets = []
                for i in det_frames:
                    all_dets.extend(i)
                with open(os.path.join(cf.exp_dir, 'all_raw_dets.pickle'), 'wb') as handle:
                    pickle.dump(all_dets, handle)

                # Concatenate detection wbc boxes across folds
                det_frames = [pd.read_pickle(os.path.join(cf.exp_dir, f, 'wbc_pred_boxes_list.pickle')) for f in os.listdir(cf.exp_dir) if 'fold_' in f]
                all_dets = []
                for i in det_frames:
                    all_dets.extend(i)
                with open(os.path.join(cf.exp_dir, 'all_wbc_dets.pickle'), 'wb') as handle:
                    pickle.dump(all_dets, handle)

    # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
    elif args.mode == 'analysis':
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
        logger = utils.get_logger(cf.exp_dir, cf.server_env)

        if args.dev:
            cf.test_n_epochs = 2

        if cf.hold_out_test_set and cf.ensemble_folds:
            # create and save (unevaluated) predictions across all folds
            predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
            results_list = predictor.load_saved_predictions(apply_wbc=True)
            utils.create_csv_output([(res_dict["boxes"], pid) for res_dict, pid in results_list], cf, logger)
            logger.info('starting evaluation...')
            cf.fold = 'overall_hold_out'
            evaluator = Evaluator(cf, logger, mode='test')
            evaluator.evaluate_predictions(results_list)
            evaluator.score_test_df()

        else:
            fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                                os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
            if folds is None:
                folds = range(cf.n_cv_splits)
            for fold in folds:
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
                cf.fold = fold
                logger.set_logfile(fold=fold)
                if cf.fold_dir in fold_dirs:
                    predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
                    results_list = predictor.load_saved_predictions(apply_wbc=True)
                    logger.info('starting evaluation...')
                    evaluator = Evaluator(cf, logger, mode='test')
                    evaluator.evaluate_predictions(results_list)
                    evaluator.score_test_df()
                else:
                    logger.info("Skipping fold {} since no model parameters found.".format(fold))

    # create experiment folder and copy scripts without starting job.
    # useful for cloud deployment where configs might change before job actually runs.
    elif args.mode == 'create_exp':
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, use_stored_settings=False)
        logger = utils.get_logger(cf.exp_dir)
        logger.info('created experiment directory at {}'.format(cf.exp_dir))

    else:
        raise RuntimeError('mode specified in args is not implemented...')


    t = utils.get_formatted_duration(time.time() - stime)
    logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t))
    del logger