a b/experiments/bleed_exp/default_configs.py
1
#!/usr/bin/env python
2
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
# ==============================================================================
16
17
"""Default Configurations script. Avoids changing configs of all experiments if general settings are to be changed."""
18
19
import os
20
21
class DefaultConfigs:
    """General settings shared by all experiments.

    Experiment-specific configs build on these defaults, so general settings can be
    changed in one place instead of in every experiment.

    Args:
        model: name of the model script; resolved to ``models/<model>.py`` (stored in
            ``self.model_path``).
        server_env: if truthy, ``source_dir`` is set to a fixed server path and the
            number of data-loader workers is fixed to 4.
        dim: data dimensionality, stored as ``self.dim`` (default 2; presumably 2 or 3
            given the 2D->3D merging setting below -- confirm with callers).
    """

    def __init__(self, model, server_env=None, dim=2):
        self.server_env = server_env
        #########################
        #         I/O           #
        #########################

        self.model = model
        self.dim = dim
        # int [0 < dataset_size]. select n patients from dataset for prototyping.
        self.select_prototype_subset = None

        # some default paths.
        self.backbone_path = 'models/backbone.py'
        self.source_dir = os.path.dirname(os.path.realpath(__file__))  # directory containing this file.
        self.input_df_name = 'info_df.pickle'
        self.model_path = 'models/{}.py'.format(self.model)

        if server_env:
            # NOTE(review): hard-coded, user-specific path; override it in the experiment
            # config when running on a different server.
            self.source_dir = '/home/jaegerp/code/mamma_code/medicaldetectiontoolkit'

        #########################
        #      Data Loader      #
        #########################

        # random seed for fold_generator and batch_generator.
        self.seed = 0

        # number of threads for multithreaded batch generation. Locally, one core is left
        # free. os.cpu_count() may return None (count undeterminable), and a single-core
        # machine would otherwise get 0 workers, so clamp to at least 1.
        self.n_workers = 4 if server_env else max(1, (os.cpu_count() or 1) - 1)

        # if True, segmentation losses learn all categories, else only foreground vs. background.
        self.class_specific_seg_flag = False

        #########################
        #      Architecture     #
        #########################

        self.weight_decay = 0.0

        # nonlinearity to be applied after convs with nonlinearity. one of 'relu' or 'leaky_relu'
        self.relu = 'relu'

        # if True initializes weights as specified in model script. else use default Pytorch init.
        self.custom_init = False

        # if True adds high-res decoder levels to feature pyramid: P1 + P0. (e.g. set to true in retina_unet configs)
        self.operate_stride1 = False

        #########################
        #  Schedule             #
        #########################

        # number of folds in cross validation.
        self.n_cv_splits = 5

        # number of probabilistic samples in validation.
        self.n_probabilistic_samples = None

        #########################
        #   Testing / Plotting  #
        #########################

        # perform mirroring at test time. (only XY. Z not done to not blow up predictions times).
        self.test_aug = True

        # if True, test data lies in a separate folder and is not part of the cross validation.
        self.hold_out_test_set = False

        # if hold_out_test_set provided, ensemble predictions over models of all trained cv-folds.
        self.ensemble_folds = False

        # color specifications for all box_types in prediction_plot.
        self.box_color_palette = {'det': 'b', 'gt': 'r', 'neg_class': 'purple',
                                  'prop': 'w', 'pos_class': 'g', 'pos_anchor': 'c', 'neg_anchor': 'c'}

        # scan over confidence score in evaluation to optimize it on the validation set.
        self.scan_det_thresh = False

        # plots roc-curves / prc-curves in evaluation.
        self.plot_stat_curves = False

        # evaluates average precision per image and averages over images, instead of computing one ap over the data set.
        self.per_patient_ap = False

        # threshold for clustering 2D box predictions to 3D Cubes. Overlap is computed in XY.
        self.merge_3D_iou = 0.1

        # monitor any value from training.
        self.n_monitoring_figures = 1
        # dict to assign specific plot_values to monitor_figures > 0. {1: ['class_loss'], 2: ['kl_loss', 'kl_sigmas']}
        self.assign_values_to_extra_figure = {}

        # save predictions to csv file in experiment dir.
        self.save_preds_to_csv = True

        # select a maximum number of patient cases to test. number or "all" for all
        self.max_test_patients = "all"

        #########################
        #   MRCNN               #
        #########################

        # if True, mask loss is not applied. used for data sets, where no pixel-wise annotations are provided.
        self.frcnn_mode = False

        # if True, unmolds masks in Mask R-CNN to full-res for plotting/monitoring.
        self.return_masks_in_val = False
        self.return_masks_in_test = False  # needed if doing instance segmentation. evaluation not yet implemented.

        # add P6 to Feature Pyramid Network.
        self.sixth_pooling = False

        # for probabilistic detection
        self.n_latent_dims = 0
139
140