Switch to unified view

a b/utils/default_config_setup.py
1
import json
2
import os
3
from enum import Enum
4
5
from dataloaders.BRAINWEB import BRAINWEB
6
from dataloaders.MSISBI2015 import MSISBI2015
7
from dataloaders.MSLUB import MSLUB
8
from dataloaders.MSSEG2008 import MSSEG2008
9
10
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
12
13
class Dataset(Enum):
14
    BRAINWEB = 'BRAINWEBDIR'
15
    MSSEG2008_UNC = 'MSSEG2008DIR'
16
    MSSEG2008_CHB = 'MSSEG2008DIR'
17
    MSISBI2015 = 'MSISBI2015DIR'
18
    MSLUB = 'MSLUBDIR'
19
20
21
def get_options(batchsize, learningrate, numEpochs, zDim, outputWidth, outputHeight, slices_start=20, slices_end=130, numMonteCarloSamples=0, config=None):
22
    options = {}
23
    # Load config.json, which should hold DATADIR, CHECKPOINTDIR and SAMPLEDIR
24
    if config:
25
        options["globals"] = config
26
    else:
27
        with open(os.path.join(base_path, "config.default.json"), 'r') as f:
28
            options["globals"] = json.load(f)
29
30
    # Options
31
    options['debug'] = False
32
    options['data'] = {}
33
    options['train'] = {}
34
    options['train']['checkpointDir'] = options["globals"]["CHECKPOINTDIR"]
35
    options['train']['samplesDir'] = options["globals"]["SAMPLEDIR"]
36
    options['train']['batchsize'] = batchsize
37
    options['train']['learningrate'] = learningrate
38
    options['train']['numEpochs'] = numEpochs
39
    options['train']['zDim'] = zDim
40
    options['train']['snapshotAfter'] = 1000  # Take a snapshot after every 50 iterations
41
    options['train']['outputWidth'] = outputWidth
42
    options['train']['outputHeight'] = outputHeight
43
    options['train']['useTensorboard'] = True
44
    options['train']['useMatplotlib'] = False
45
    options['train']['tensorboardPort'] = 9001
46
    options['sliceStart'] = slices_start  # 20
47
    options['sliceEnd'] = slices_end  # 130
48
    options['threshold'] = 'bestdice'
49
    options['exportVolumes'] = False
50
    options['exportPRC'] = True
51
    options['exportROC'] = True
52
    options['numMonteCarloSamples'] = numMonteCarloSamples
53
    options['keepOnlyPositiveResiduals'] = True
54
    options['applyHyperIntensityPrior'] = True
55
    options['medianFiltering'] = True
56
    options['erodeBrainmask'] = True
57
    return options
58
59
60
def get_datasets(options, dataset: Dataset = Dataset.BRAINWEB):
61
    if dataset == Dataset.BRAINWEB:
62
        return get_Brainweb_healthy_dataset(options), get_Brainweb_lesion_dataset(options)
63
    elif dataset == Dataset.MSSEG2008_UNC:
64
        return None, get_MSSEG2008_dataset(options, 'UNC')
65
    elif dataset == Dataset.MSSEG2008_CHB:
66
        return None, get_MSSEG2008_dataset(options, 'CHB')
67
    elif dataset == Dataset.MSISBI2015:
68
        return None, get_MSISBI2015_dataset(options)
69
    elif dataset == Dataset.MSLUB:
70
        return None, get_MSLUB_dataset(options)
71
    else:
72
        raise ValueError(f'No valid dataset given: {dataset}')
73
74
75
###########################
76
#       MSSEG2008         #
77
###########################
78
def get_MSSEG2008_dataset(options, filter_sanner):
79
    dataset_options = get_MSSEG2008_dataset_options(options, filter_sanner)
80
    dataset = MSSEG2008(dataset_options)
81
    if options['debug']:
82
        dataset.visualize()
83
84
    return dataset
85
86
87
def get_MSSEG2008_dataset_options(options, filter_sanner):
88
    dataset_options = MSSEG2008.Options()
89
    dataset_options.description = ''
90
    dataset_options.debug = options['debug']
91
    dataset_options.dir = options['globals']['MSSEG2008DIR']
92
    dataset_options.useCrops = False
93
    dataset_options.cropType = 'center'  # Crop patches around lesions
94
    dataset_options.cropWidth = options['train']['outputWidth']
95
    dataset_options.cropHeight = options['train']['outputHeight']
96
    dataset_options.numRandomCropsPerSlice = 5  # Not needed when doing center crops
97
    dataset_options.rotations = [0]
98
    dataset_options.partition = {'TRAIN': 0, 'VAL': 2, 'TEST': 8}
99
    dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']]
100
    dataset_options.cache = True
101
    dataset_options.numSamples = -1
102
    dataset_options.addInstanceNoise = False
103
    dataset_options.axis = 'axial'
104
    dataset_options.filterScanner = filter_sanner  # 'UNC'or 'CHB'
105
    dataset_options.filterProtocols = ['FLAIR']
106
    dataset_options.filterType = "train"
107
    dataset_options.normalizationMethod = 'scaling'
108
    dataset_options.skullStripping = True
109
    dataset_options.sliceStart = options['sliceStart']
110
    dataset_options.sliceEnd = options['sliceEnd']
111
    dataset_options.skullStripping = True
112
    dataset_options.format = "aligned"
113
114
    return dataset_options
115
116
117
###########################
118
#       MSISBI2015        #
119
###########################
120
def get_MSISBI2015_dataset(options):
121
    dataset_options = get_MSISBI2015_dataset_options(options)
122
    dataset = MSISBI2015(dataset_options)
123
    if options['debug']:
124
        dataset.visualize()
125
126
    return dataset
127
128
129
def get_MSISBI2015_dataset_options(options):
130
    dataset_options = MSISBI2015.Options()
131
    dataset_options.description = ''
132
    dataset_options.debug = options['debug']
133
    dataset_options.dir = options['globals']['MSISBI2015DIR']
134
    dataset_options.useCrops = False
135
    dataset_options.cropType = 'center'  # Crop patches around lesions
136
    dataset_options.cropWidth = options['train']['outputWidth']
137
    dataset_options.cropHeight = options['train']['outputHeight']
138
    dataset_options.numRandomCropsPerSlice = 5  # Not needed when doing center crops
139
    dataset_options.rotations = [0]
140
    dataset_options.partition = {'TRAIN': 0, 'VAL': 5, 'TEST': 15}
141
    dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']]
142
    dataset_options.cache = True
143
    dataset_options.numSamples = -1
144
    dataset_options.addInstanceNoise = False
145
    dataset_options.axis = 'axial'
146
    dataset_options.filterProtocols = ['FLAIR']
147
    dataset_options.filterType = "train"
148
    dataset_options.normalizationMethod = 'scaling'
149
    dataset_options.skullStripping = True
150
    dataset_options.sliceStart = options['sliceStart']
151
    dataset_options.sliceEnd = options['sliceEnd']
152
    dataset_options.skullStripping = True
153
    dataset_options.format = "aligned"
154
    return dataset_options
155
156
157
###########################
158
#       MSLUB        #
159
###########################
160
def get_MSLUB_dataset(options):
161
    dataset_options = get_MSLUB_dataset_options(options)
162
    dataset = MSLUB(dataset_options)
163
    if options['debug']:
164
        dataset.visualize()
165
166
    return dataset
167
168
169
def get_MSLUB_dataset_options(options):
170
    dataset_options = MSLUB.Options()
171
    dataset_options.description = ''
172
    dataset_options.debug = options['debug']
173
    dataset_options.dir = options['globals']['MSLUBDIR']
174
    dataset_options.useCrops = False
175
    dataset_options.cropType = 'center'  # Crop patches around lesions
176
    dataset_options.cropWidth = options['train']['outputWidth']
177
    dataset_options.cropHeight = options['train']['outputHeight']
178
    dataset_options.numRandomCropsPerSlice = 5  # Not needed when doing center crops
179
    dataset_options.rotations = [0]
180
    dataset_options.partition = {'TRAIN': 0, 'VAL': 5, 'TEST': 25}
181
    dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']]
182
    dataset_options.cache = True
183
    dataset_options.numSamples = -1
184
    dataset_options.addInstanceNoise = False
185
    dataset_options.axis = 'axial'
186
    dataset_options.filterProtocols = ['FLAIR']
187
    dataset_options.normalizationMethod = 'scaling'
188
    dataset_options.skullStripping = True
189
    dataset_options.sliceStart = options['sliceStart']
190
    dataset_options.sliceEnd = options['sliceEnd']
191
    dataset_options.skullStripping = True
192
    dataset_options.format = "aligned"
193
194
    return dataset_options
195
196
197
#######################
198
#      Brainweb       #
199
#######################
200
def get_Brainweb_healthy_dataset(options):
201
    dataset_options = get_Brainweb_dataset_options(options)
202
    dataset_hc = BRAINWEB(dataset_options)
203
    if options['debug']:
204
        dataset_hc.visualize()
205
    return dataset_hc
206
207
208
def get_Brainweb_lesion_dataset(options):
209
    dataset_options = get_Brainweb_dataset_options(options)
210
    # Center Crops of slices from patients with lesions. Only for testing
211
    dataset_options.partition = {'TRAIN': 0.0, 'VAL': 0.0, 'TEST': 1.0}
212
    dataset_options.filterType = 'SEVEREMS'
213
    dataset_options.rotations = [0]
214
    return BRAINWEB(dataset_options)
215
216
217
def get_Brainweb_dataset_options(options):
218
    dataset_options = BRAINWEB.Options()
219
    dataset_options.description = ""
220
    dataset_options.debug = options['debug']
221
    dataset_options.dir = options['data']['dir']
222
    dataset_options.useCrops = False
223
    dataset_options.cropType = 'center'  # Not used when useCrops is False
224
    dataset_options.cropWidth = options['train']['outputWidth']
225
    dataset_options.cropHeight = options['train']['outputHeight']
226
    dataset_options.numRandomCropsPerSlice = 5  # Not needed when doing center crops
227
    dataset_options.rotations = [0]
228
    dataset_options.partition = {'TRAIN': 0.7, 'VAL': 0.3, 'TEST': 0.0}
229
    dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']]
230
    dataset_options.cache = True
231
    dataset_options.numSamples = -1
232
    dataset_options.addInstanceNoise = False
233
    dataset_options.axis = 'axial'
234
    dataset_options.filterType = 'NORMAL'
235
    dataset_options.filterProtocol = 'T2'
236
    dataset_options.normalizationMethod = 'scaling'
237
    dataset_options.skullRemoval = True
238
    dataset_options.sliceStart = options['sliceStart']
239
    dataset_options.sliceEnd = options['sliceEnd']
240
    dataset_options.backgroundRemoval = True
241
    dataset_options.registerTo = None
242
    return dataset_options
243
244
245
def get_config(trainer, options, optimizer, intermediateResolutions, dropout_rate, dataset):
246
    config = trainer.Config()
247
    config.dataset = type(dataset).__name__
248
    config.description = ''
249
    config.numChannels = dataset.num_channels
250
    config.batchsize = options['train']['batchsize']
251
    config.checkpointDir = options['train']['checkpointDir']
252
    config.snapShotAfter = options['train']['snapshotAfter']
253
    config.sampleDir = options['train']['samplesDir']
254
    config.learningrate = options['train']['learningrate']
255
    config.numEpochs = options['train']['numEpochs']
256
    config.zDim = options['train']['zDim']
257
    config.beta1 = 0.5
258
    config.outputHeight = options['train']['outputHeight']
259
    config.outputWidth = options['train']['outputWidth']
260
    config.useTensorboard = options['train']['useTensorboard']
261
    config.useMatplotlib = options['train']['useMatplotlib']
262
    config.tensorboardPort = options['train']['tensorboardPort']
263
    config.debugGradients = options['debug']
264
    config.optimizer = optimizer
265
    config.intermediateResolutions = intermediateResolutions
266
    config.weightRegularization = 0.0
267
    config.dropout_rate = dropout_rate
268
    config.dropout = False
269
    config.l1_weight = 1.0
270
    config.options = options
271
    return config