Diff of /medseg_dl/parameters.py [000000] .. [6d4aaa]

Switch to unified view

a b/medseg_dl/parameters.py
1
import types
2
import yaml
3
import logging
4
import os
5
import tempfile
6
import shutil
7
import datetime
8
9
"""Parameter file"""
10
11
12
def fetch_options(idx_dataset=0):
    """Assemble the default tf parameter set as a ``types.SimpleNamespace``.

    :param idx_dataset: index selecting one of the hardcoded dataset paths
    :return: namespace carrying all default parameters
    :raises ValueError: if ``idx_dataset`` does not map to a known dataset
    """
    if idx_dataset == 0:
        dir_data = '/home/NAKO/'
    else:
        raise ValueError('Chosen dataset idx is not available')

    cfg = {
        # setup
        'device': '0',  # standard device, can be bypassed if device is provided during call

        # i/o
        'path_parser_cfg': '/home/io_patterns_nako.yaml',
        'dir_data': dir_data,
        'set_split': 70,

        # pipeline & augmentation
        'shape_image': [320, 260, 316],
        'shape_image_eval': [160, 160, 160],  # smaller image -> faster eval (fewer patches); previously [240, 240, 240]
        'shape_input': [64, 64, 64],    # has to correspond to your designed model
        'shape_output': [16, 16, 16],   # has to correspond to your designed model
        'size_batch': 24,
        'size_batch_eval': 50,
        'size_buffer': 1,
        'num_parallel_calls': 4,
        'repeat': 5,
        'b_shuffle': True,
        'b_eval_labels_patch': False,
        'b_eval_labels_image': True,
        'patches_per_class': [2, 1, 1, 1, 1, 1],  # Note: atm has to be reflected in your input pipeline #2

        # patch augmentation
        'sigma_offset': 0.1,
        'sigma_noise': 0.05,
        'sigma_pos': 0.08,

        # image augmentation
        'b_mirror': False,
        'b_rotate': True,
        'b_scale': True,
        'b_warp': False,
        'b_permute_labels': False,
        'angle_max': 7,
        'scale_factor': 0.08,
        'delta_max': 0,

        # model
        'channels': 2,
        'channels_out': 6,
        'b_dynamic_pos_mid': True,
        'b_dynamic_pos_end': False,
        'filters': 32,
        'dense_layers': 2,
        'alpha': 0.2,
        'rate_dropout': 0.0,

        # optimizer
        'rate_learning': 0.00001,   # std: 1e-2 - 1e-6, too low: slow learning
        'beta1': 0.9,               # std: 0.9
        'beta2': 0.999,             # std: 0.999
        'epsilon': 0.00000001,      # std: 1e-8, too high: slow learning

        # session
        'num_epochs': 701,
        'b_continuous_eval': True,
        'b_restore': False,
        'save_summary_steps': 1,
        'b_viewer_train': False,
        'b_viewer_eval': False,
        'b_save_pred': False,

        # logging
        'log_level': logging.INFO,

        # seed
        'b_use_seed': False,
        'random_seed': 100,
    }

    return types.SimpleNamespace(**cfg)
93
94
95
class Params(object):
    """YAML-backed parameter container.

    Parameters live directly in ``__dict__`` so they are reachable both as
    attributes (``params.size_batch``) and via the ``dict`` property.
    """

    def __init__(self, path_yaml='', model_dir='', idx_dataset=-1, b_recreate=False):
        """Load (or create) a parameter file and optionally set up the run dir.

        :param path_yaml: path of the yaml parameter file ('' -> temp file on save)
        :param model_dir: if given, create the timestamped run tree below it
        :param idx_dataset: dataset index forwarded to ``fetch_options``
        :param b_recreate: force regeneration of the yaml file from defaults
        """
        # set yaml path:
        self.path_yaml = path_yaml

        # fetch dataset idx to choose hardcoded split
        self.idx_dataset = idx_dataset

        # add passed/generated params file
        self.update(b_recreate)

        # add default values
        if model_dir:
            self.set_model_dir(model_dir)

    def create(self):
        """Populate this instance with defaults from ``fetch_options`` and persist them."""
        logging.info('Creating a new set of parameters in %s', self.path_yaml)
        self.__dict__.update(fetch_options(idx_dataset=self.idx_dataset).__dict__)
        self.save()

    def save(self):
        """Dump all parameters to ``self.path_yaml`` (a fresh temp file if unset)."""
        if not self.path_yaml:
            _, self.path_yaml = tempfile.mkstemp()
        try:
            with open(self.path_yaml, 'w') as file:
                yaml.dump(self.__dict__, file, indent=4)
        except Exception:
            # Clean up the possibly partially-written file, but re-raise the
            # ORIGINAL exception (bare `raise`) instead of an opaque
            # `raise Exception()` that would discard the cause and traceback.
            os.remove(self.path_yaml)
            raise

    def update(self, b_recreate):
        """(Re)load parameters from the yaml file, creating defaults if needed.

        :param b_recreate: when True, always regenerate the file from defaults
        """
        if b_recreate or not os.path.isfile(self.path_yaml):
            self.create()

        with open(self.path_yaml, 'r') as file:
            # safe_load: plain yaml.load without an explicit Loader is
            # deprecated (errors on modern PyYAML) and can execute arbitrary
            # code on untrusted input.
            params = yaml.safe_load(file)
            # safe_load returns None for an empty file -> update with nothing
            self.__dict__.update(params or {})

    def set_path(self, path_yaml):
        """Point the instance at a (new) yaml file without loading it."""
        self.path_yaml = path_yaml

    def set_model_dir(self, model_dir):
        """Create the timestamped run directory tree and move the params file into it."""
        d = self.__dict__
        d['date'] = datetime.datetime.now().strftime('_%Y-%m-%dT%H-%M-%S')
        d['dir_model'] = os.path.join(model_dir, 'run' + d['date'])
        d['dir_logs_train'] = os.path.join(d['dir_model'], 'logs', 'train')
        d['dir_logs_eval'] = os.path.join(d['dir_model'], 'logs', 'eval')
        d['dir_graphs_train'] = os.path.join(d['dir_model'], 'graphs', 'train')
        d['dir_graphs_eval'] = os.path.join(d['dir_model'], 'graphs', 'eval')
        d['dir_ckpts'] = os.path.join(d['dir_model'], 'ckpts')
        d['dir_ckpts_best'] = os.path.join(d['dir_model'], 'ckpts_best')

        # Create environment: every attribute whose name contains 'dir' is a path
        for k, v in d.items():
            if 'dir' in k and not os.path.exists(v):
                os.makedirs(v, exist_ok=True)

        # Move params to corresponding model folder
        self.move(os.path.join(d['dir_model'], 'params.yaml'))

        # Save with dirs
        self.save()

    def move(self, path_yaml_new):
        """Move the backing yaml file to ``path_yaml_new`` and track the new location."""
        shutil.move(self.path_yaml, path_yaml_new)
        self.path_yaml = path_yaml_new

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return self.__dict__