# configs_fpred_patch/luna_c3.py
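# LUNA candidate patch classifier: 48x48x48 mm patches at 1 mm isotropic
# spacing, fed to a 3D inception-resnet-style network with a two-way softmax.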
import numpy as np
import data_transforms
import data_iterators
import pathfinder
import lasagne as nn
from collections import namedtuple
from functools import partial
import lasagne.layers.dnn as dnn
import lasagne
import theano.tensor as T
import utils

restart_from_save = None
rng = np.random.RandomState(42)

# transformations: 48x48x48 voxel patches resampled to 1 mm isotropic spacing;
# training patches get small random translations and full-range 3D rotations
p_transform = {'patch_size': (48, 48, 48),
               'mm_patch_size': (48, 48, 48),
               'pixel_spacing': (1., 1., 1.)
               }
p_transform_augment = {
    'translation_range_z': [-3, 3],
    'translation_range_y': [-3, 3],
    'translation_range_x': [-3, 3],
    'rotation_range_z': [-180, 180],
    'rotation_range_y': [-180, 180],
    'rotation_range_x': [-180, 180]
}


# data preparation function: extract a (possibly augmented) 3D patch around
# patch_center and normalize the Hounsfield units
def data_prep_function(data, patch_center, pixel_spacing, luna_origin, p_transform,
                       p_transform_augment, world_coord_system, **kwargs):
    # the returned annotation transform is unused here (luna_annotations=None)
    x, _ = data_transforms.transform_patch3d(data=data,
                                             luna_annotations=None,
                                             patch_center=patch_center,
                                             p_transform=p_transform,
                                             p_transform_augment=p_transform_augment,
                                             pixel_spacing=pixel_spacing,
                                             luna_origin=luna_origin,
                                             world_coord_system=world_coord_system)
    x = data_transforms.pixelnormHU(x)
    return x
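
# Train-time preparation uses the augmentation parameters defined above;
# validation preparation passes p_transform_augment=None, so patches are
# extracted deterministically.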
data_prep_function_train = partial(data_prep_function, p_transform_augment=p_transform_augment,
                                   p_transform=p_transform, world_coord_system=True)
data_prep_function_valid = partial(data_prep_function, p_transform_augment=None,
                                   p_transform=p_transform, world_coord_system=True)

# data iterators
batch_size = 16
nbatches_chunk = 1
chunk_size = batch_size * nbatches_chunk

train_valid_ids = utils.load_pkl(pathfinder.LUNA_VALIDATION_SPLIT_PATH)
train_pids, valid_pids = train_valid_ids['train'], train_valid_ids['valid']
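
# The training generator samples class-balanced chunks (positive_proportion=0.5)
# from the training patients; the validation generator covers the held-out
# patients without augmentation.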
train_data_iterator = data_iterators.CandidatesLunaDataGenerator(data_path=pathfinder.LUNA_DATA_PATH,
                                                                 batch_size=chunk_size,
                                                                 transform_params=p_transform,
                                                                 data_prep_fun=data_prep_function_train,
                                                                 rng=rng,
                                                                 patient_ids=train_pids,
                                                                 full_batch=True, random=True, infinite=True,
                                                                 positive_proportion=0.5)

valid_data_iterator = data_iterators.CandidatesLunaValidDataGenerator(data_path=pathfinder.LUNA_DATA_PATH,
                                                                      transform_params=p_transform,
                                                                      data_prep_fun=data_prep_function_valid,
                                                                      patient_ids=valid_pids)

# floor division so the chunk count stays an integer under Python 3 as well
nchunks_per_epoch = train_data_iterator.nsamples // chunk_size
max_nchunks = nchunks_per_epoch * 100

validate_every = int(5. * nchunks_per_epoch)
save_every = int(1. * nchunks_per_epoch)
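
# step-wise learning rate decay, keyed by training chunk index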
learning_rate_schedule = {
    0: 5e-4,
    int(max_nchunks * 0.5): 2e-4,
    int(max_nchunks * 0.6): 1e-4,
    int(max_nchunks * 0.7): 5e-5,
    int(max_nchunks * 0.8): 2e-5,
    int(max_nchunks * 0.9): 1e-5
}

# model
conv3d = partial(dnn.Conv3DDNNLayer,
                 filter_size=3,
                 pad='same',
                 W=nn.init.Orthogonal(),
                 nonlinearity=nn.nonlinearities.very_leaky_rectify)

max_pool3d = partial(dnn.MaxPool3DDNNLayer,
                     pool_size=2)

drop = lasagne.layers.DropoutLayer

dense = partial(lasagne.layers.DenseLayer,
                W=lasagne.init.Orthogonal(),
                nonlinearity=lasagne.nonlinearities.very_leaky_rectify)
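
# Inception-ResNet-style residual block: three parallel branches (1x1,
# 1x1 -> 3x3, 1x1 -> 3x3 -> 3x3) are concatenated, projected back to the
# input channel count with a 1x1 conv, and added to the input.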
def inrn_v2(lin):
    n_base_filter = 32

    l1 = conv3d(lin, n_base_filter, filter_size=1)

    l2 = conv3d(lin, n_base_filter, filter_size=1)
    l2 = conv3d(l2, n_base_filter, filter_size=3)

    l3 = conv3d(lin, n_base_filter, filter_size=1)
    l3 = conv3d(l3, n_base_filter, filter_size=3)
    l3 = conv3d(l3, n_base_filter, filter_size=3)

    l = lasagne.layers.ConcatLayer([l1, l2, l3])

    l = conv3d(l, lin.output_shape[1], filter_size=1)

    l = lasagne.layers.ElemwiseSumLayer([l, lin])

    l = lasagne.layers.NonlinearityLayer(l, nonlinearity=lasagne.nonlinearities.rectify)

    return l


def inrn_v2_red(lin):
    # Reduce the total activation volume by a factor of 4: every spatial
    # dimension is halved (8x fewer voxels) while the channel count doubles,
    # since the four branches below concatenate to 2 * ins feature maps.
    den = 16
    nom2 = 4
    nom3 = 5
    nom4 = 7

    ins = lin.output_shape[1]

    l1 = max_pool3d(lin)

    l2 = conv3d(lin, ins // den * nom2, filter_size=3, stride=2)

    l3 = conv3d(lin, ins // den * nom2, filter_size=1)
    l3 = conv3d(l3, ins // den * nom3, filter_size=3, stride=2)

    l4 = conv3d(lin, ins // den * nom2, filter_size=1)
    l4 = conv3d(l4, ins // den * nom3, filter_size=3)
    l4 = conv3d(l4, ins // den * nom4, filter_size=3, stride=2)

    l = lasagne.layers.ConcatLayer([l1, l2, l3, l4])

    return l


def feat_red(lin):
    # We want to reduce the feature maps by a factor of 2
    ins = lin.output_shape[1]
    l = conv3d(lin, ins // 2, filter_size=1)
    return l
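
# Full network: a stem conv followed by two reduction stages, each interleaving
# residual blocks with 1x1 feature reductions, ending in dropout + dense + softmax.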
def build_model():
    l_in = nn.layers.InputLayer((None, 1,) + p_transform['patch_size'])
    l_target = nn.layers.InputLayer((None, 1))

    l = conv3d(l_in, 64)
    l = inrn_v2_red(l)
    l = inrn_v2(l)
    l = feat_red(l)
    l = inrn_v2(l)

    l = inrn_v2_red(l)
    l = inrn_v2(l)
    l = feat_red(l)
    l = inrn_v2(l)

    l = feat_red(l)

    l = dense(drop(l), 128)

    l_out = nn.layers.DenseLayer(l, num_units=2,
                                 W=nn.init.Constant(0.),
                                 nonlinearity=nn.nonlinearities.softmax)

    return namedtuple('Model', ['l_in', 'l_out', 'l_target'])(l_in, l_out, l_target)
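
# Objective: mean negative log-likelihood of the target class; the selected
# probabilities are clipped away from zero for numerical stability.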
def build_objective(model, deterministic=False, epsilon=1e-12):
    predictions = nn.layers.get_output(model.l_out, deterministic=deterministic)
    targets = T.cast(T.flatten(nn.layers.get_output(model.l_target)), 'int32')
    p = predictions[T.arange(predictions.shape[0]), targets]
    p = T.clip(p, epsilon, 1.)
    loss = T.mean(T.log(p))
    return -loss


def build_updates(train_loss, model, learning_rate):
    updates = nn.updates.adam(train_loss, nn.layers.get_all_params(model.l_out, trainable=True), learning_rate)
    return updates
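

if __name__ == '__main__':
    # Minimal smoke test; a sketch, not part of the original config (the real
    # training harness that consumes build_model / build_objective /
    # build_updates lives elsewhere in the repo). It only builds the symbolic
    # graph and reports the parameter count and output shape.
    import theano

    model = build_model()
    train_loss = build_objective(model, deterministic=False)
    learning_rate = theano.shared(np.float32(learning_rate_schedule[0]))
    updates = build_updates(train_loss, model, learning_rate)
    print('number of parameters: %d' % nn.layers.count_params(model.l_out))
    print('output shape: %s' % str(model.l_out.output_shape))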