|
a |
|
b/utils/default_config_setup.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
from enum import Enum |
|
|
4 |
|
|
|
5 |
from dataloaders.BRAINWEB import BRAINWEB |
|
|
6 |
from dataloaders.MSISBI2015 import MSISBI2015 |
|
|
7 |
from dataloaders.MSLUB import MSLUB |
|
|
8 |
from dataloaders.MSSEG2008 import MSSEG2008 |
|
|
9 |
|
|
|
10 |
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
class Dataset(Enum): |
|
|
14 |
BRAINWEB = 'BRAINWEBDIR' |
|
|
15 |
MSSEG2008_UNC = 'MSSEG2008DIR' |
|
|
16 |
MSSEG2008_CHB = 'MSSEG2008DIR' |
|
|
17 |
MSISBI2015 = 'MSISBI2015DIR' |
|
|
18 |
MSLUB = 'MSLUBDIR' |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def get_options(batchsize, learningrate, numEpochs, zDim, outputWidth, outputHeight, slices_start=20, slices_end=130, numMonteCarloSamples=0, config=None): |
|
|
22 |
options = {} |
|
|
23 |
# Load config.json, which should hold DATADIR, CHECKPOINTDIR and SAMPLEDIR |
|
|
24 |
if config: |
|
|
25 |
options["globals"] = config |
|
|
26 |
else: |
|
|
27 |
with open(os.path.join(base_path, "config.default.json"), 'r') as f: |
|
|
28 |
options["globals"] = json.load(f) |
|
|
29 |
|
|
|
30 |
# Options |
|
|
31 |
options['debug'] = False |
|
|
32 |
options['data'] = {} |
|
|
33 |
options['train'] = {} |
|
|
34 |
options['train']['checkpointDir'] = options["globals"]["CHECKPOINTDIR"] |
|
|
35 |
options['train']['samplesDir'] = options["globals"]["SAMPLEDIR"] |
|
|
36 |
options['train']['batchsize'] = batchsize |
|
|
37 |
options['train']['learningrate'] = learningrate |
|
|
38 |
options['train']['numEpochs'] = numEpochs |
|
|
39 |
options['train']['zDim'] = zDim |
|
|
40 |
options['train']['snapshotAfter'] = 1000 # Take a snapshot after every 50 iterations |
|
|
41 |
options['train']['outputWidth'] = outputWidth |
|
|
42 |
options['train']['outputHeight'] = outputHeight |
|
|
43 |
options['train']['useTensorboard'] = True |
|
|
44 |
options['train']['useMatplotlib'] = False |
|
|
45 |
options['train']['tensorboardPort'] = 9001 |
|
|
46 |
options['sliceStart'] = slices_start # 20 |
|
|
47 |
options['sliceEnd'] = slices_end # 130 |
|
|
48 |
options['threshold'] = 'bestdice' |
|
|
49 |
options['exportVolumes'] = False |
|
|
50 |
options['exportPRC'] = True |
|
|
51 |
options['exportROC'] = True |
|
|
52 |
options['numMonteCarloSamples'] = numMonteCarloSamples |
|
|
53 |
options['keepOnlyPositiveResiduals'] = True |
|
|
54 |
options['applyHyperIntensityPrior'] = True |
|
|
55 |
options['medianFiltering'] = True |
|
|
56 |
options['erodeBrainmask'] = True |
|
|
57 |
return options |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
def get_datasets(options, dataset: Dataset = Dataset.BRAINWEB): |
|
|
61 |
if dataset == Dataset.BRAINWEB: |
|
|
62 |
return get_Brainweb_healthy_dataset(options), get_Brainweb_lesion_dataset(options) |
|
|
63 |
elif dataset == Dataset.MSSEG2008_UNC: |
|
|
64 |
return None, get_MSSEG2008_dataset(options, 'UNC') |
|
|
65 |
elif dataset == Dataset.MSSEG2008_CHB: |
|
|
66 |
return None, get_MSSEG2008_dataset(options, 'CHB') |
|
|
67 |
elif dataset == Dataset.MSISBI2015: |
|
|
68 |
return None, get_MSISBI2015_dataset(options) |
|
|
69 |
elif dataset == Dataset.MSLUB: |
|
|
70 |
return None, get_MSLUB_dataset(options) |
|
|
71 |
else: |
|
|
72 |
raise ValueError(f'No valid dataset given: {dataset}') |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
########################### |
|
|
76 |
# MSSEG2008 # |
|
|
77 |
########################### |
|
|
78 |
def get_MSSEG2008_dataset(options, filter_sanner): |
|
|
79 |
dataset_options = get_MSSEG2008_dataset_options(options, filter_sanner) |
|
|
80 |
dataset = MSSEG2008(dataset_options) |
|
|
81 |
if options['debug']: |
|
|
82 |
dataset.visualize() |
|
|
83 |
|
|
|
84 |
return dataset |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
def get_MSSEG2008_dataset_options(options, filter_sanner): |
|
|
88 |
dataset_options = MSSEG2008.Options() |
|
|
89 |
dataset_options.description = '' |
|
|
90 |
dataset_options.debug = options['debug'] |
|
|
91 |
dataset_options.dir = options['globals']['MSSEG2008DIR'] |
|
|
92 |
dataset_options.useCrops = False |
|
|
93 |
dataset_options.cropType = 'center' # Crop patches around lesions |
|
|
94 |
dataset_options.cropWidth = options['train']['outputWidth'] |
|
|
95 |
dataset_options.cropHeight = options['train']['outputHeight'] |
|
|
96 |
dataset_options.numRandomCropsPerSlice = 5 # Not needed when doing center crops |
|
|
97 |
dataset_options.rotations = [0] |
|
|
98 |
dataset_options.partition = {'TRAIN': 0, 'VAL': 2, 'TEST': 8} |
|
|
99 |
dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']] |
|
|
100 |
dataset_options.cache = True |
|
|
101 |
dataset_options.numSamples = -1 |
|
|
102 |
dataset_options.addInstanceNoise = False |
|
|
103 |
dataset_options.axis = 'axial' |
|
|
104 |
dataset_options.filterScanner = filter_sanner # 'UNC'or 'CHB' |
|
|
105 |
dataset_options.filterProtocols = ['FLAIR'] |
|
|
106 |
dataset_options.filterType = "train" |
|
|
107 |
dataset_options.normalizationMethod = 'scaling' |
|
|
108 |
dataset_options.skullStripping = True |
|
|
109 |
dataset_options.sliceStart = options['sliceStart'] |
|
|
110 |
dataset_options.sliceEnd = options['sliceEnd'] |
|
|
111 |
dataset_options.skullStripping = True |
|
|
112 |
dataset_options.format = "aligned" |
|
|
113 |
|
|
|
114 |
return dataset_options |
|
|
115 |
|
|
|
116 |
|
|
|
117 |
########################### |
|
|
118 |
# MSISBI2015 # |
|
|
119 |
########################### |
|
|
120 |
def get_MSISBI2015_dataset(options): |
|
|
121 |
dataset_options = get_MSISBI2015_dataset_options(options) |
|
|
122 |
dataset = MSISBI2015(dataset_options) |
|
|
123 |
if options['debug']: |
|
|
124 |
dataset.visualize() |
|
|
125 |
|
|
|
126 |
return dataset |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
def get_MSISBI2015_dataset_options(options): |
|
|
130 |
dataset_options = MSISBI2015.Options() |
|
|
131 |
dataset_options.description = '' |
|
|
132 |
dataset_options.debug = options['debug'] |
|
|
133 |
dataset_options.dir = options['globals']['MSISBI2015DIR'] |
|
|
134 |
dataset_options.useCrops = False |
|
|
135 |
dataset_options.cropType = 'center' # Crop patches around lesions |
|
|
136 |
dataset_options.cropWidth = options['train']['outputWidth'] |
|
|
137 |
dataset_options.cropHeight = options['train']['outputHeight'] |
|
|
138 |
dataset_options.numRandomCropsPerSlice = 5 # Not needed when doing center crops |
|
|
139 |
dataset_options.rotations = [0] |
|
|
140 |
dataset_options.partition = {'TRAIN': 0, 'VAL': 5, 'TEST': 15} |
|
|
141 |
dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']] |
|
|
142 |
dataset_options.cache = True |
|
|
143 |
dataset_options.numSamples = -1 |
|
|
144 |
dataset_options.addInstanceNoise = False |
|
|
145 |
dataset_options.axis = 'axial' |
|
|
146 |
dataset_options.filterProtocols = ['FLAIR'] |
|
|
147 |
dataset_options.filterType = "train" |
|
|
148 |
dataset_options.normalizationMethod = 'scaling' |
|
|
149 |
dataset_options.skullStripping = True |
|
|
150 |
dataset_options.sliceStart = options['sliceStart'] |
|
|
151 |
dataset_options.sliceEnd = options['sliceEnd'] |
|
|
152 |
dataset_options.skullStripping = True |
|
|
153 |
dataset_options.format = "aligned" |
|
|
154 |
return dataset_options |
|
|
155 |
|
|
|
156 |
|
|
|
157 |
########################### |
|
|
158 |
# MSLUB # |
|
|
159 |
########################### |
|
|
160 |
def get_MSLUB_dataset(options): |
|
|
161 |
dataset_options = get_MSLUB_dataset_options(options) |
|
|
162 |
dataset = MSLUB(dataset_options) |
|
|
163 |
if options['debug']: |
|
|
164 |
dataset.visualize() |
|
|
165 |
|
|
|
166 |
return dataset |
|
|
167 |
|
|
|
168 |
|
|
|
169 |
def get_MSLUB_dataset_options(options): |
|
|
170 |
dataset_options = MSLUB.Options() |
|
|
171 |
dataset_options.description = '' |
|
|
172 |
dataset_options.debug = options['debug'] |
|
|
173 |
dataset_options.dir = options['globals']['MSLUBDIR'] |
|
|
174 |
dataset_options.useCrops = False |
|
|
175 |
dataset_options.cropType = 'center' # Crop patches around lesions |
|
|
176 |
dataset_options.cropWidth = options['train']['outputWidth'] |
|
|
177 |
dataset_options.cropHeight = options['train']['outputHeight'] |
|
|
178 |
dataset_options.numRandomCropsPerSlice = 5 # Not needed when doing center crops |
|
|
179 |
dataset_options.rotations = [0] |
|
|
180 |
dataset_options.partition = {'TRAIN': 0, 'VAL': 5, 'TEST': 25} |
|
|
181 |
dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']] |
|
|
182 |
dataset_options.cache = True |
|
|
183 |
dataset_options.numSamples = -1 |
|
|
184 |
dataset_options.addInstanceNoise = False |
|
|
185 |
dataset_options.axis = 'axial' |
|
|
186 |
dataset_options.filterProtocols = ['FLAIR'] |
|
|
187 |
dataset_options.normalizationMethod = 'scaling' |
|
|
188 |
dataset_options.skullStripping = True |
|
|
189 |
dataset_options.sliceStart = options['sliceStart'] |
|
|
190 |
dataset_options.sliceEnd = options['sliceEnd'] |
|
|
191 |
dataset_options.skullStripping = True |
|
|
192 |
dataset_options.format = "aligned" |
|
|
193 |
|
|
|
194 |
return dataset_options |
|
|
195 |
|
|
|
196 |
|
|
|
197 |
####################### |
|
|
198 |
# Brainweb # |
|
|
199 |
####################### |
|
|
200 |
def get_Brainweb_healthy_dataset(options): |
|
|
201 |
dataset_options = get_Brainweb_dataset_options(options) |
|
|
202 |
dataset_hc = BRAINWEB(dataset_options) |
|
|
203 |
if options['debug']: |
|
|
204 |
dataset_hc.visualize() |
|
|
205 |
return dataset_hc |
|
|
206 |
|
|
|
207 |
|
|
|
208 |
def get_Brainweb_lesion_dataset(options): |
|
|
209 |
dataset_options = get_Brainweb_dataset_options(options) |
|
|
210 |
# Center Crops of slices from patients with lesions. Only for testing |
|
|
211 |
dataset_options.partition = {'TRAIN': 0.0, 'VAL': 0.0, 'TEST': 1.0} |
|
|
212 |
dataset_options.filterType = 'SEVEREMS' |
|
|
213 |
dataset_options.rotations = [0] |
|
|
214 |
return BRAINWEB(dataset_options) |
|
|
215 |
|
|
|
216 |
|
|
|
217 |
def get_Brainweb_dataset_options(options): |
|
|
218 |
dataset_options = BRAINWEB.Options() |
|
|
219 |
dataset_options.description = "" |
|
|
220 |
dataset_options.debug = options['debug'] |
|
|
221 |
dataset_options.dir = options['data']['dir'] |
|
|
222 |
dataset_options.useCrops = False |
|
|
223 |
dataset_options.cropType = 'center' # Not used when useCrops is False |
|
|
224 |
dataset_options.cropWidth = options['train']['outputWidth'] |
|
|
225 |
dataset_options.cropHeight = options['train']['outputHeight'] |
|
|
226 |
dataset_options.numRandomCropsPerSlice = 5 # Not needed when doing center crops |
|
|
227 |
dataset_options.rotations = [0] |
|
|
228 |
dataset_options.partition = {'TRAIN': 0.7, 'VAL': 0.3, 'TEST': 0.0} |
|
|
229 |
dataset_options.sliceResolution = [options['train']['outputHeight'], options['train']['outputWidth']] |
|
|
230 |
dataset_options.cache = True |
|
|
231 |
dataset_options.numSamples = -1 |
|
|
232 |
dataset_options.addInstanceNoise = False |
|
|
233 |
dataset_options.axis = 'axial' |
|
|
234 |
dataset_options.filterType = 'NORMAL' |
|
|
235 |
dataset_options.filterProtocol = 'T2' |
|
|
236 |
dataset_options.normalizationMethod = 'scaling' |
|
|
237 |
dataset_options.skullRemoval = True |
|
|
238 |
dataset_options.sliceStart = options['sliceStart'] |
|
|
239 |
dataset_options.sliceEnd = options['sliceEnd'] |
|
|
240 |
dataset_options.backgroundRemoval = True |
|
|
241 |
dataset_options.registerTo = None |
|
|
242 |
return dataset_options |
|
|
243 |
|
|
|
244 |
|
|
|
245 |
def get_config(trainer, options, optimizer, intermediateResolutions, dropout_rate, dataset): |
|
|
246 |
config = trainer.Config() |
|
|
247 |
config.dataset = type(dataset).__name__ |
|
|
248 |
config.description = '' |
|
|
249 |
config.numChannels = dataset.num_channels |
|
|
250 |
config.batchsize = options['train']['batchsize'] |
|
|
251 |
config.checkpointDir = options['train']['checkpointDir'] |
|
|
252 |
config.snapShotAfter = options['train']['snapshotAfter'] |
|
|
253 |
config.sampleDir = options['train']['samplesDir'] |
|
|
254 |
config.learningrate = options['train']['learningrate'] |
|
|
255 |
config.numEpochs = options['train']['numEpochs'] |
|
|
256 |
config.zDim = options['train']['zDim'] |
|
|
257 |
config.beta1 = 0.5 |
|
|
258 |
config.outputHeight = options['train']['outputHeight'] |
|
|
259 |
config.outputWidth = options['train']['outputWidth'] |
|
|
260 |
config.useTensorboard = options['train']['useTensorboard'] |
|
|
261 |
config.useMatplotlib = options['train']['useMatplotlib'] |
|
|
262 |
config.tensorboardPort = options['train']['tensorboardPort'] |
|
|
263 |
config.debugGradients = options['debug'] |
|
|
264 |
config.optimizer = optimizer |
|
|
265 |
config.intermediateResolutions = intermediateResolutions |
|
|
266 |
config.weightRegularization = 0.0 |
|
|
267 |
config.dropout_rate = dropout_rate |
|
|
268 |
config.dropout = False |
|
|
269 |
config.l1_weight = 1.0 |
|
|
270 |
config.options = options |
|
|
271 |
return config |