Diff of /utils/exp_utils.py [000000] .. [bb7f56]

#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, Tuple, Any, Union
import os, sys
import subprocess
from multiprocessing import Process

import importlib.util
import pickle

import logging
from torch.utils.tensorboard import SummaryWriter

from collections import OrderedDict
import numpy as np
import torch
import pandas as pd

def split_off_process(target, *args, daemon: bool=False, **kwargs):
    """Start a process that does not block the parent script.

    The process is started but never joined here; the Process handle is returned.
    :param target: the target function of the process.
    :param args: positional args to pass to target.
    :param daemon: if False, the parent waits for this process to finish before it exits.
    :param kwargs: kwargs to pass to target.
    :return: the started multiprocessing.Process instance.
    """
    p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=daemon)
    p.start()
    return p

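# Illustrative usage (editor's sketch, not part of the original source): fire off a
# long-running side task, e.g. background plotting, without blocking the training loop.
#
#   proc = split_off_process(print, "plotting in background", daemon=False)
#   # the parent continues immediately; since daemon=False, Python waits for
#   # `proc` to finish before the parent process exits.
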
def get_formatted_duration(seconds: float, format: str="hms") -> str:
    """Format a duration given in seconds.
    :param format: "hms" for hours:minutes:seconds, "ms" for minutes:seconds.
    """
    mins, secs = divmod(seconds, 60)
    if format == "ms":
        t = "{:d}m:{:02d}s".format(int(mins), int(secs))
    elif format == "hms":
        h, mins = divmod(mins, 60)
        t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
    else:
        raise ValueError("Format {} not available, only 'hms' or 'ms'".format(format))
    return t

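# Example (illustrative, not part of the original source):
#   get_formatted_duration(3725)               # -> '1h:02m:05s'
#   get_formatted_duration(3725, format="ms")  # -> '62m:05s'
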
class CombinedLogger(object):
    """Combine console and tensorboard logger and record system metrics.
    """

    def __init__(self, name: str, log_dir: str, server_env: bool=True, fold: Union[int, str]="all"):
        self.pylogger = logging.getLogger(name)
        self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard"))
        self.log_dir = log_dir
        self.fold = str(fold)
        self.server_env = server_env

        self.pylogger.setLevel(logging.DEBUG)
        self.log_file = os.path.join(log_dir, "fold_"+self.fold, 'exec.log')
        os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
        self.pylogger.addHandler(logging.FileHandler(self.log_file))
        if not server_env:
            self.pylogger.addHandler(ColorHandler())
        else:
            self.pylogger.addHandler(logging.StreamHandler())
        self.pylogger.propagate = False

    def __getattr__(self, attr):
        """Delegate all undefined attribute requests to the wrapped objects,
        in order pylogger, tboard (first match wins).
        E.g., combinedlogger.add_scalars(...) triggers self.tboard.add_scalars(...).
        """
        for obj in [self.pylogger, self.tboard]:
            if attr in dir(obj):
                return getattr(obj, attr)
        raise AttributeError("CombinedLogger has no attribute {}".format(attr))

    def set_logfile(self, fold: Union[int, str, None]=None, log_file: Union[str, None]=None):
        if fold is not None:
            self.fold = str(fold)
        if log_file is None:
            self.log_file = os.path.join(self.log_dir, "fold_"+self.fold, 'exec.log')
        else:
            self.log_file = log_file
        os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
        for hdlr in self.pylogger.handlers:
            hdlr.close()
        self.pylogger.handlers = []
        self.pylogger.addHandler(logging.FileHandler(self.log_file))
        if not self.server_env:
            self.pylogger.addHandler(ColorHandler())
        else:
            self.pylogger.addHandler(logging.StreamHandler())

    def metrics2tboard(self, metrics, global_step=None, suptitle=None):
        """Write the latest train/val metrics to tensorboard.
        :param metrics: {'train': df, 'val': df}, dataframes as produced by
            evaluator.py's evaluate_predictions.
        """
        # print("metrics", metrics)
        if global_step is None:
            global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1
        if suptitle is not None:
            suptitle = str(suptitle)
        else:
            suptitle = "Fold_" + str(self.fold)

        for key in ['train', 'val']:
            # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k}
            loss_series = {}
            mon_met_series = {}
            for tag, val in metrics[key].items():
                val = val[-1]  # maybe remove list wrapping, recording in evaluator?
                if 'loss' in tag.lower() and not np.isnan(val):
                    loss_series["{}".format(tag)] = val
                elif not np.isnan(val):
                    mon_met_series["{}".format(tag)] = val

            self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step)
            self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step)
        self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step)
        return

    def __del__(self):  # otherwise might produce multiple prints e.g. in ipython console
        for hdlr in self.pylogger.handlers:
            hdlr.close()
        self.pylogger.handlers = []
        del self.pylogger
        self.tboard.flush()
        # close somehow prevents main script from exiting
        # maybe revise this issue in a later pytorch version
        #self.tboard.close()

def get_logger(exp_dir: str, server_env: bool=False) -> CombinedLogger:
    """Create a logger instance that writes info to file, to the terminal, and to tensorboard.
    :param exp_dir: experiment directory, where the exec.log file is stored.
    :param server_env: True if operating in a server environment (e.g., gpu cluster).
    :return: custom CombinedLogger instance.
    """
    log_dir = os.path.join(exp_dir, "logs")
    logger = CombinedLogger('medicaldetectiontoolkit', log_dir, server_env=server_env)
    print("Logging to {}".format(logger.log_file))
    return logger

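# Illustrative usage (editor's sketch, not part of the original source; the experiment path is a placeholder):
#   logger = get_logger("/path/to/my_experiment", server_env=False)
#   logger.info("starting fold 0")  # handled by the wrapped python logger (file + console)
#   logger.add_scalars("Fold_0/Losses/train", {"total": 0.7}, 1)  # delegated to the tensorboard writer
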
def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True):
    """
    I/O handling and creation of the experiment folder structure. Also creates a snapshot of the configs/model
    scripts and copies them to exp_dir, so that exp_dir contains all info needed to conduct the experiment
    independently of later changes to the source code; training/inference of this experiment can thus be
    (re)started at any time. For this, the model script is copied back to the source code dir as tmp_model (tmp_backbone).
    Provides a robust structure for cloud deployment.
    :param dataset_path: path to the source code for the specific data set (e.g. medicaldetectiontoolkit/lidc_exp).
    :param exp_path: path to the experiment directory.
    :param server_env: boolean flag. passed to the configs script for cloud deployment.
    :param use_stored_settings: boolean flag. When starting training: if True, starts training from the snapshot
        in the existing experiment directory, else creates the experiment directory on the fly using
        configs/model scripts from the source code.
    :param is_training: boolean flag. distinguishes train vs. inference mode.
    :return: configs object cf.
    """

    if is_training:
        if use_stored_settings:
            cf_file = import_module('cf_file', os.path.join(exp_path, 'configs.py'))
            cf = cf_file.configs(server_env)
            # in this mode, previously saved model and backbone need to be found in exp dir.
            if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \
                    not os.path.isfile(os.path.join(exp_path, 'backbone.py')):
                raise Exception(
                    "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.")
            cf.model_path = os.path.join(exp_path, 'model.py')
            cf.backbone_path = os.path.join(exp_path, 'backbone.py')
        else:
            # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model
            os.makedirs(exp_path, exist_ok=True)
            # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.)
            subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')),
                            shell=True)
            subprocess.call(
                'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')),
                shell=True)
            cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py'))
            cf = cf_file.configs(server_env)
            subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True)
            subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True)
            if os.path.isfile(os.path.join(exp_path, "folds_ids.pickle")):
                subprocess.call('rm {}'.format(os.path.join(exp_path, "folds_ids.pickle")), shell=True)

    else:
        # testing, use model and backbone stored in exp dir.
        cf_file = import_module('cf_file', os.path.join(exp_path, 'configs.py'))
        cf = cf_file.configs(server_env)
        cf.model_path = os.path.join(exp_path, 'model.py')
        cf.backbone_path = os.path.join(exp_path, 'backbone.py')

    cf.exp_dir = exp_path
    cf.test_dir = os.path.join(cf.exp_dir, 'test')
    cf.plot_dir = os.path.join(cf.exp_dir, 'plots')
    if not os.path.exists(cf.test_dir):
        os.mkdir(cf.test_dir)
    if not os.path.exists(cf.plot_dir):
        os.mkdir(cf.plot_dir)
    cf.experiment_name = exp_path.split("/")[-1]
    cf.created_fold_id_pickle = False

    return cf

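# Illustrative usage (editor's sketch, not part of the original source); typically called from the
# main training/inference script with the parsed command-line arguments. Both paths are placeholders.
#
#   cf = prep_exp("medicaldetectiontoolkit/lidc_exp", "/path/to/experiments/my_lidc_run",
#                 server_env=False, use_stored_settings=False, is_training=True)
#   # cf now carries exp_dir, test_dir, plot_dir, model_path, backbone_path, ...
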
def import_module(name: str, path: str):
    """
    Import a module dynamically from a file path (the recommended way in Python 3).
    :param name: name given to the module instance.
    :param path: path to the module file.
    :return: the imported module instance.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

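# Example (illustrative, not part of the original source; the configs path is a placeholder):
#   cf_file = import_module('cf_file', 'medicaldetectiontoolkit/lidc_exp/configs.py')
#   cf = cf_file.configs(False)  # instantiate the dataset's configs with server_env=False
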
def set_params_flag(module: torch.nn.Module, flag: Tuple[str, Any], check_overwrite: bool = True) -> torch.nn.Module:
    """Set an attribute on all parameters of the passed module.

    :param module: module whose parameters are flagged.
    :param flag: tuple (attribute name, attribute value).
    :param check_overwrite: if True, assert that the attribute does not already exist.
    """
    for param in module.parameters():
        if check_overwrite:
            assert not hasattr(param, flag[0]), \
                "param {} already has attr {} (w/ val {})".format(param, flag[0], getattr(param, flag[0]))
        setattr(param, flag[0], flag[1])
    return module

def parse_params_for_optim(net: torch.nn.Module, weight_decay: float = 0., exclude_from_wd: Iterable = ("norm",)) -> list:
    """Split network parameters into weight-decay dependent groups for the optimizer.
    :param net: network.
    :param weight_decay: weight decay value applied to the non-excluded parameters. Excluded parameters get
        weight decay 0.
    :param exclude_from_wd: list of strings of parameter-group names to exclude from weight decay. Options: "norm", "bias".
    :return: list of parameter groups as expected by torch optimizers.
    """
    if weight_decay is None:
        weight_decay = 0.
    # pytorch implements parameter groups as dicts {'params': ...} and
    # weight decay as p.data.mul_(1 - group['lr'] * group['weight_decay'])
    norm_types = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
                  torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d,
                  torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.SyncBatchNorm, torch.nn.LocalResponseNorm]
    level_map = {"bias": "weight",
                 "norm": "module"}
    type_map = {"norm": norm_types}

    exclude_from_wd = [str(name).lower() for name in exclude_from_wd]
    exclude_weight_names = [k for k, v in level_map.items() if k in exclude_from_wd and v == "weight"]
    exclude_module_types = tuple([type_ for k, v in level_map.items() if (k in exclude_from_wd and v == "module")
                                  for type_ in type_map[k]])

    if exclude_from_wd:
        print("excluding {} from weight decay.".format(exclude_from_wd))

    for module in net.modules():
        if isinstance(module, exclude_module_types):
            set_params_flag(module, ("no_wd", True))
    for param_name, param in net.named_parameters():
        if np.any([ename in param_name for ename in exclude_weight_names]):
            setattr(param, "no_wd", True)

    with_dec, no_dec = [], []
    for param in net.parameters():
        if hasattr(param, "no_wd") and param.no_wd == True:
            no_dec.append(param)
        else:
            with_dec.append(param)
    orig_ps = sum(p.numel() for p in net.parameters())
    with_ps = sum(p.numel() for p in with_dec)
    wo_ps = sum(p.numel() for p in no_dec)
    assert orig_ps == with_ps + wo_ps, "orig n parameters {} does not equal sum of with wd {} and w/o wd {}."\
        .format(orig_ps, with_ps, wo_ps)

    groups = [{'params': gr, 'weight_decay': wd} for (gr, wd) in [(no_dec, 0.), (with_dec, weight_decay)] if len(gr) > 0]
    return groups

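# Illustrative usage (editor's sketch, not part of the original source): the returned groups plug
# directly into a torch optimizer; `net` stands for any torch.nn.Module.
#
#   param_groups = parse_params_for_optim(net, weight_decay=1e-5, exclude_from_wd=("norm", "bias"))
#   optimizer = torch.optim.AdamW(param_groups, lr=1e-4)
#   # norm and bias parameters now train with weight_decay=0, everything else with 1e-5.
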
class ModelSelector:
    '''
    Saves a checkpoint after each epoch as 'last_checkpoint' (can be loaded to continue an interrupted training).
    Also saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be
    ensembled to improve performance.
    '''

    def __init__(self, cf, logger):

        self.cf = cf
        self.saved_epochs = [-1] * cf.save_n_models
        self.logger = logger

    def run_model_selection(self, net: torch.nn.Module, optimizer: torch.optim.Optimizer,
                            monitor_metrics: dict, epoch: int):

        # take the mean over all selection criteria in each epoch
        non_nan_scores = np.mean(np.array([[0 if (ii is None or np.isnan(ii)) else ii for ii in monitor_metrics['val'][sc]] for sc in self.cf.model_selection_criteria]), 0)
        epochs_scores = [ii for ii in non_nan_scores[1:]]
        # ranking of epochs according to model_selection_criterion
        epoch_ranking = np.argsort(epochs_scores, kind="stable")[::-1] + 1  # epochs start at 1
        # if set in configs, epochs < min_save_thresh are discarded from the saving process.
        epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh]

        # check if current epoch is among the top-k epochs.
        if epoch in epoch_ranking[:self.cf.save_n_models]:

            save_dir = os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(epoch))
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)

            torch.save(net.state_dict(), os.path.join(save_dir, 'params.pth'))
            with open(os.path.join(save_dir, 'monitor_metrics.pickle'), 'wb') as handle:
                pickle.dump(monitor_metrics, handle)
            # save epoch_ranking to keep info for inference.
            np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
            np.save(os.path.join(save_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])

            self.logger.info(
                "saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch)))
            # delete params of the epoch that just fell out of the top-k epochs.
            for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_checkpoint' in ii]:
                if se in epoch_ranking[self.cf.save_n_models:]:
                    subprocess.call('rm -rf {}'.format(os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(se))), shell=True)
                    self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se)))

        state = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        # save checkpoint of current epoch.
        save_dir = os.path.join(self.cf.fold_dir, 'last_checkpoint')
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save(state, os.path.join(save_dir, 'params.pth'))
        np.save(os.path.join(save_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
        with open(os.path.join(save_dir, 'monitor_metrics.pickle'), 'wb') as handle:
            pickle.dump(monitor_metrics, handle)

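# Illustrative usage inside a training loop (editor's sketch, not part of the original source;
# `cf.num_epochs`, `net`, `optimizer` and `monitor_metrics` are assumed to exist in the caller):
#   selector = ModelSelector(cf, logger)
#   for epoch in range(1, cf.num_epochs + 1):
#       ...  # train + validate, appending the epoch's values to monitor_metrics
#       selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
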
def load_checkpoint(checkpoint_path: str, net: torch.nn.Module, optimizer: torch.optim.Optimizer) -> Tuple:
    """Load net and optimizer states plus monitor metrics from a checkpoint directory."""
    checkpoint = torch.load(os.path.join(checkpoint_path, 'params.pth'))
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    with open(os.path.join(checkpoint_path, 'monitor_metrics.pickle'), 'rb') as handle:
        monitor_metrics = pickle.load(handle)
    starting_epoch = checkpoint['epoch'] + 1
    return starting_epoch, net, optimizer, monitor_metrics

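# Illustrative usage (editor's sketch, not part of the original source; cf, net, optimizer come from the caller):
#   starting_epoch, net, optimizer, monitor_metrics = load_checkpoint(
#       os.path.join(cf.fold_dir, 'last_checkpoint'), net, optimizer)
#   # training can then resume at `starting_epoch`.
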
def prepare_monitoring(cf):
    """
    Create the dictionaries in which train/val metrics are stored.
    """
    metrics = {}
    # first entry for loss dict accounts for epoch starting at 1.
    metrics['train'] = OrderedDict()
    metrics['val'] = OrderedDict()
    metric_classes = []
    if 'rois' in cf.report_score_level:
        metric_classes.extend([v for k, v in cf.class_dict.items()])
    if 'patient' in cf.report_score_level:
        metric_classes.extend(['patient'])
    for cl in metric_classes:
        metrics['train'][cl + '_ap'] = [np.nan]
        metrics['val'][cl + '_ap'] = [np.nan]
        if cl == 'patient':
            metrics['train'][cl + '_auc'] = [np.nan]
            metrics['val'][cl + '_auc'] = [np.nan]

    return metrics

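# Resulting structure (illustrative; assumes class_dict = {1: 'benign', 2: 'malignant'} and
# report_score_level = ['rois', 'patient'], which are not taken from the original source):
#   {'train': OrderedDict([('benign_ap', [nan]), ('malignant_ap', [nan]),
#                          ('patient_ap', [nan]), ('patient_auc', [nan])]),
#    'val':   OrderedDict([...same keys...])}
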
def create_csv_output(results_list, cf, logger):
    """
    Write out test set predictions to a .csv file. Output format is one line per prediction:
    PatientID | PredictionID | [y1 x1 y2 x2 (z1) (z2)] | score | pred_classID
    Note that prediction coordinates correspond to images as loaded for training/testing and need to be adapted when
    plotted over raw data (before preprocessing/resampling).
    :param results_list: [[patient_results, patient_id], [patient_results, patient_id], ...]
    """

    logger.info('creating csv output file in {}'.format(cf.exp_dir))
    predictions_df = pd.DataFrame(columns=['patientID', 'predictionID', 'coords', 'score', 'pred_classID'])
    for r in results_list:

        pid = r[1]

        # optionally load resampling info from preprocessing to match output predictions with raw data.
        # with open(os.path.join(cf.exp_dir, 'test_resampling_info', pid), 'rb') as handle:
        #     resampling_info = pickle.load(handle)

        for bix, box in enumerate(r[0][0]):
            if box["box_type"] == "gt":
                continue
            assert box['box_type'] == 'det', box['box_type']
            coords = box['box_coords']
            score = box['box_score']
            pred_class_id = box['box_pred_class_id']
            out_coords = []
            if score >= cf.min_det_thresh:
                out_coords.append(coords[0])  # * resampling_info['scale'][0]
                out_coords.append(coords[1])  # * resampling_info['scale'][1]
                out_coords.append(coords[2])  # * resampling_info['scale'][0]
                out_coords.append(coords[3])  # * resampling_info['scale'][1]
                if len(coords) > 4:
                    out_coords.append(coords[4])  # * resampling_info['scale'][2] + resampling_info['z_crop']
                    out_coords.append(coords[5])  # * resampling_info['scale'][2] + resampling_info['z_crop']

                predictions_df.loc[len(predictions_df)] = [pid, bix, out_coords, score, pred_class_id]
    try:
        fold = cf.fold
    except AttributeError:
        fold = 'hold_out'
    predictions_df.to_csv(os.path.join(cf.exp_dir, 'results_{}.csv'.format(fold)), index=False)

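# Example of a resulting csv row (illustrative values, not real data):
#   patientID,predictionID,coords,score,pred_classID
#   pat_0042,3,"[112.0, 88.5, 160.0, 141.0, 12.0, 31.0]",0.87,1
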
class _AnsiColorizer(object):
    """
    A colorizer is an object that loosely wraps around a stream, allowing
    callers to write text to the stream in a particular color.

    Colorizer classes must implement C{supported()} and C{write(text, color)}.
    """
    _colors = dict(black=30, red=31, green=32, yellow=33,
                   blue=34, magenta=35, cyan=36, white=37, default=39)

    def __init__(self, stream):
        self.stream = stream

    @classmethod
    def supported(cls, stream=sys.stdout):
        """
        A class method that returns True if the current platform supports
        coloring terminal output using this method. Returns False otherwise.
        """
        if not stream.isatty():
            return False  # auto color only on TTYs
        try:
            import curses
        except ImportError:
            return False
        else:
            try:
                try:
                    return curses.tigetnum("colors") > 2
                except curses.error:
                    curses.setupterm()
                    return curses.tigetnum("colors") > 2
            except Exception:
                # guess false in case of error
                return False

    def write(self, text, color):
        """
        Write the given text to the stream in the given color.

        @param text: Text to be written to the stream.

        @param color: A string label for a color. e.g. 'red', 'white'.
        """
        color = self._colors[color]
        self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text))

class ColorHandler(logging.StreamHandler):

    def __init__(self, stream=sys.stdout):
        super(ColorHandler, self).__init__(_AnsiColorizer(stream))

    def emit(self, record):
        msg_colors = {
            logging.DEBUG: "green",
            logging.INFO: "default",
            logging.WARNING: "red",
            logging.ERROR: "red"
        }
        color = msg_colors.get(record.levelno, "blue")
        self.stream.write(record.msg + "\n", color)
506