|
a |
|
b/configs_seg_patch/luna_p5.py |
|
|
1 |
import numpy as np |
|
|
2 |
import data_transforms |
|
|
3 |
import data_iterators |
|
|
4 |
import pathfinder |
|
|
5 |
import lasagne as nn |
|
|
6 |
from collections import namedtuple |
|
|
7 |
from functools import partial |
|
|
8 |
import lasagne.layers.dnn as dnn |
|
|
9 |
import theano.tensor as T |
|
|
10 |
import utils |
|
|
11 |
|
|
|
12 |
# Checkpoint to resume from; None means training starts from scratch
# (presumably the training script reads this -- TODO confirm).
restart_from_save = None

# Fixed-seed RNG (handed to the training data iterator) for reproducibility.
rng = np.random.RandomState(seed=42)

# transformations
# Patch geometry: 64^3 voxels with unit pixel spacing.
p_transform = dict(
    patch_size=(64, 64, 64),
    mm_patch_size=(64, 64, 64),
    pixel_spacing=(1., 1., 1.),
)

# Random augmentation ranges used only by the training prep function:
# translations of +-16 along each axis and full +-180 rotations
# (units are whatever data_transforms.transform_patch3d expects -- presumably
# mm and degrees; confirm there).
p_transform_augment = dict(
    translation_range_z=[-16, 16],
    translation_range_y=[-16, 16],
    translation_range_x=[-16, 16],
    rotation_range_z=[-180, 180],
    rotation_range_y=[-180, 180],
    rotation_range_x=[-180, 180],
)

# Fixed zero-mean / unit-variance statistics applied after hu2normHU.
# Setting either to None makes the estimation block further down recompute them.
zmuv_mean = 0.36
zmuv_std = 0.31
|
|
30 |
|
|
|
31 |
|
|
|
32 |
# data preparation function |
|
|
33 |
def data_prep_function(data, patch_center, luna_annotations, pixel_spacing, luna_origin, p_transform,
                       p_transform_augment, **kwargs):
    """Cut one 3-D patch around ``patch_center`` and build its nodule mask.

    The patch is extracted by ``data_transforms.transform_patch3d`` (which also
    receives ``p_transform_augment``; presumably it augments only when that is
    not None). It is then mapped through ``hu2normHU`` and standardised with the
    module-level ``zmuv_mean`` / ``zmuv_std``, which are looked up at call time.
    The target is produced by ``make_3d_mask_from_annotations`` from the
    transformed annotations, using ``img_shape`` equal to the patch's shape and
    ``shape='sphere'``.

    Extra keyword arguments are accepted and ignored.

    :return: tuple ``(x, y)``: the normalised patch and its mask.
    """
    # The middle element (annotations of the patch itself) is not needed here.
    patch, _, patch_annotations = data_transforms.transform_patch3d(
        data=data,
        luna_annotations=luna_annotations,
        patch_center=patch_center,
        p_transform=p_transform,
        p_transform_augment=p_transform_augment,
        pixel_spacing=pixel_spacing,
        luna_origin=luna_origin)

    patch = data_transforms.hu2normHU(patch)
    patch = data_transforms.zmuv(patch, zmuv_mean, zmuv_std)

    mask = data_transforms.make_3d_mask_from_annotations(img_shape=patch.shape,
                                                         annotations=patch_annotations,
                                                         shape='sphere')
    return patch, mask
|
|
46 |
|
|
|
47 |
|
|
|
48 |
# Training prep passes the augmentation ranges; validation prep passes None.
data_prep_function_train = partial(data_prep_function,
                                   p_transform=p_transform,
                                   p_transform_augment=p_transform_augment)
data_prep_function_valid = partial(data_prep_function,
                                   p_transform=p_transform,
                                   p_transform_augment=None)
|
|
50 |
|
|
|
51 |
# data iterators |
|
|
52 |
# One "chunk" = nbatches_chunk mini-batches of batch_size patches; the training
# generator is asked for whole chunks at a time.
batch_size = 4
nbatches_chunk = 8
chunk_size = nbatches_chunk * batch_size

# Patient-id split: a mapping with 'train' and 'valid' entries.
train_valid_ids = utils.load_pkl(pathfinder.LUNA_VALIDATION_SPLIT_PATH)
train_pids = train_valid_ids['train']
valid_pids = train_valid_ids['valid']

# Training patches: seeded rng, full random chunks, never-ending stream.
train_data_iterator = data_iterators.PatchPositiveLunaDataGenerator(
    data_path=pathfinder.LUNA_DATA_PATH,
    patient_ids=train_pids,
    batch_size=chunk_size,
    transform_params=p_transform,
    data_prep_fun=data_prep_function_train,
    rng=rng,
    full_batch=True,
    random=True,
    infinite=True)

# Validation patches use the non-augmenting prep function.
valid_data_iterator = data_iterators.ValidPatchPositiveLunaDataGenerator(
    data_path=pathfinder.LUNA_DATA_PATH,
    patient_ids=valid_pids,
    transform_params=p_transform,
    data_prep_fun=data_prep_function_valid)
|
|
71 |
if zmuv_mean is None or zmuv_std is None:
    # Estimate the zero-mean/unit-variance statistics from a few training chunks.
    import itertools

    print('estimating ZMUV parameters')
    # data_prep_function reads zmuv_mean / zmuv_std at call time, so sampling with
    # them still set to None would feed None into data_transforms.zmuv. Use the
    # identity normalisation while sampling, so the statistics are measured on
    # the hu2normHU output that zmuv standardises.
    # NOTE(review): assumes zmuv(x, m, s) == (x - m) / s -- confirm in data_transforms.
    zmuv_mean, zmuv_std = 0., 1.

    x_chunks = []
    for i, (x, _, _) in enumerate(itertools.islice(train_data_iterator.generate(), 4)):
        print(i)
        x_chunks.append(x)
    if not x_chunks:
        raise RuntimeError('training iterator produced no chunks for ZMUV estimation')
    # Concatenate once instead of re-copying a growing array on every chunk.
    x_big = np.concatenate(x_chunks, axis=0)
    del x_chunks

    zmuv_mean = x_big.mean()
    zmuv_std = x_big.std()
    print('mean: %s' % zmuv_mean)
    print('std: %s' % zmuv_std)
|
|
81 |
|
|
|
82 |
# Training length in chunks (one chunk = chunk_size samples). Floor division keeps
# these integers on Python 3 as well as Python 2. The lower bound of 1 keeps the
# schedule well-formed when the training set is smaller than one chunk: with 0,
# every schedule key below would collapse to 0 and the last rate would silently win.
nchunks_per_epoch = max(1, train_data_iterator.nsamples // chunk_size)
max_nchunks = nchunks_per_epoch * 30

validate_every = int(2. * nchunks_per_epoch)
# Never 0 (it would be 0 whenever an epoch is a single chunk); presumably the
# training loop uses it as a modulus.
save_every = max(1, int(0.5 * nchunks_per_epoch))

# Chunk index -> learning rate, stepping down at 40/50/80/90 % of training.
learning_rate_schedule = {
    0: 1e-5,
    int(max_nchunks * 0.4): 5e-6,
    int(max_nchunks * 0.5): 2e-6,
    int(max_nchunks * 0.8): 1e-6,
    int(max_nchunks * 0.9): 5e-7
}
|
|
95 |
|
|
|
96 |
# model |
|
|
97 |
# 3x3x3 unpadded ('valid') cuDNN convolution with a linear output; the
# non-linearity is applied by a separate layer (see conv_prelu_layer).
conv3d = partial(dnn.Conv3DDNNLayer,
                 filter_size=3,
                 pad='valid',
                 nonlinearity=nn.nonlinearities.identity,
                 W=nn.init.Orthogonal(gain='relu'),
                 b=nn.init.Constant(0.))
|
|
103 |
|
|
|
104 |
# 2x2x2 max-pooling (stride defaults to pool_size in Lasagne).
max_pool3d = partial(dnn.MaxPool3DDNNLayer, pool_size=2)
|
|
106 |
|
|
|
107 |
|
|
|
108 |
def conv_prelu_layer(l_in, n_filters):
    """Stack a ``conv3d`` convolution and a learnable PReLU on top of ``l_in``.

    :param l_in: incoming Lasagne layer.
    :param n_filters: number of convolution filters.
    :return: the ``ParametricRectifierLayer`` wrapping the convolution.
    """
    conv = conv3d(l_in, num_filters=n_filters)
    return nn.layers.ParametricRectifierLayer(conv)
|
|
112 |
|
|
|
113 |
|
|
|
114 |
def build_model():
    """Build a one-level 3-D U-Net-style segmentation network.

    Every 3x3x3 convolution is unpadded, so the output is spatially smaller than
    the input. For the 64^3 patch configured here the arithmetic is
    64 -> 58 (3 convs) -> 29 (pool) -> 21 (4 convs) -> 42 (upscale, skip cropped
    to 42) -> 32 (5 convs), giving a 32^3 output. build_objective center-crops
    the target to match.

    :return: namedtuple ``Model(l_in, l_out, l_target)``. ``l_in`` and
        ``l_target`` are input layers of shape ``(None, 1) + patch_size``;
        ``l_out`` is a 1-channel 1x1x1 convolution with sigmoid output.
    """
    patch_shape = (None, 1) + p_transform['patch_size']
    l_in = nn.layers.InputLayer(patch_shape)
    l_target = nn.layers.InputLayer(patch_shape)

    net = {}
    base_n_filters = 128
    # Floor division: `/` yields a float (64.0) on Python 3, which is not a valid
    # filter count; on Python 2 both give 64.
    half_n_filters = base_n_filters // 2

    # Contracting path at full resolution, then 2x down-sampling.
    net['contr_1_1'] = conv_prelu_layer(l_in, base_n_filters)
    net['contr_1_2'] = conv_prelu_layer(net['contr_1_1'], base_n_filters)
    net['contr_1_3'] = conv_prelu_layer(net['contr_1_2'], base_n_filters)
    net['pool1'] = max_pool3d(net['contr_1_3'])

    # Bottleneck at half resolution.
    net['encode_1'] = conv_prelu_layer(net['pool1'], base_n_filters)
    net['encode_2'] = conv_prelu_layer(net['encode_1'], base_n_filters)
    net['encode_3'] = conv_prelu_layer(net['encode_2'], base_n_filters)
    net['encode_4'] = conv_prelu_layer(net['encode_3'], base_n_filters)

    # Upsample and concatenate with the skip connection; the skip is larger
    # (valid convs), so both are center-cropped to the common spatial size.
    net['upscale1'] = nn.layers.Upscale3DLayer(net['encode_4'], 2)
    net['concat1'] = nn.layers.ConcatLayer([net['upscale1'], net['contr_1_3']],
                                           cropping=(None, None, "center", "center", "center"))

    # Lasagne DropoutLayer with its default rate; disabled when outputs are
    # requested with deterministic=True.
    net['dropout_1'] = nn.layers.DropoutLayer(net['concat1'])

    # Expanding path.
    net['expand_1_1'] = conv_prelu_layer(net['dropout_1'], 2 * base_n_filters)
    net['expand_1_2'] = conv_prelu_layer(net['expand_1_1'], base_n_filters)
    net['expand_1_3'] = conv_prelu_layer(net['expand_1_2'], base_n_filters)
    net['expand_1_4'] = conv_prelu_layer(net['expand_1_3'], half_n_filters)
    net['expand_1_5'] = conv_prelu_layer(net['expand_1_4'], half_n_filters)

    # Per-voxel probability map.
    l_out = dnn.Conv3DDNNLayer(net['expand_1_5'], num_filters=1,
                               filter_size=1,
                               nonlinearity=nn.nonlinearities.sigmoid)

    return namedtuple('Model', ['l_in', 'l_out', 'l_target'])(l_in, l_out, l_target)
|
|
147 |
|
|
|
148 |
|
|
|
149 |
def build_objective(model, deterministic=False, epsilon=1e-12):
    """Return the negative soft Dice coefficient between predictions and targets.

    :param model: namedtuple with ``l_out`` (network output) and ``l_target``.
    :param deterministic: forwarded to ``lasagne.layers.get_output`` for the
        network output, so stochastic layers such as dropout run in inference
        mode when True (use it for the validation objective).
    :param epsilon: smoothing term added to numerator and denominator so the
        ratio stays defined when both prediction and target are all zero.
    :return: scalar Theano expression ``-dice``; minimising it maximises overlap.
    """
    # Forward `deterministic`; without it, dropout stays active even when a
    # deterministic (validation) objective is requested.
    network_predictions = nn.layers.get_output(model.l_out, deterministic=deterministic)
    target_values = nn.layers.get_output(model.l_target)
    # Unpadded convolutions shrink the output; crop both tensors to their common
    # centered spatial extent, leaving batch and channel axes alone.
    network_predictions, target_values = nn.layers.merge.autocrop(
        [network_predictions, target_values],
        [None, None, 'center', 'center', 'center'])

    # Dice over the whole batch: all voxels of all samples are pooled.
    intersection = T.sum(target_values * network_predictions)
    dice = (2 * intersection + epsilon) / (T.sum(target_values) + T.sum(network_predictions) + epsilon)
    return -1. * dice
|
|
160 |
|
|
|
161 |
|
|
|
162 |
def build_updates(train_loss, model, learning_rate):
    """Create Adam update rules for every parameter reachable from ``model.l_out``.

    :param train_loss: scalar Theano loss expression to minimise.
    :param model: namedtuple whose ``l_out`` is the network's output layer.
    :param learning_rate: Adam step size (float or symbolic scalar).
    :return: the updates returned by ``lasagne.updates.adam``.
    """
    all_params = nn.layers.get_all_params(model.l_out)
    return nn.updates.adam(train_loss, all_params, learning_rate=learning_rate)