# configs_seg_patch/luna_p3.py
1
import numpy as np
2
import data_transforms
3
import data_iterators
4
import pathfinder
5
import lasagne as nn
6
from collections import namedtuple
7
from functools import partial
8
import lasagne.layers.dnn as dnn
9
import theano.tensor as T
10
import utils
11
12
# Checkpoint to resume from; None in this config.
# NOTE(review): presumably read by the training script -- confirm the expected value type.
restart_from_save = None
# Fixed-seed random generator; handed to the training data generator below.
rng = np.random.RandomState(42)

# transformations
# Patch geometry passed to the data transforms and to both data generators.
# NOTE(review): 'patch_size' is also used as the spatial shape of the network input.
p_transform = {
    'patch_size': (64, 64, 64),
    'mm_patch_size': (64, 64, 64),
    'pixel_spacing': (1., 1., 1.),
}
# Augmentation ranges per axis; only the training data-prep function receives them
# (the validation one is built with p_transform_augment=None).
# NOTE(review): units (voxels/mm, degrees) are defined by data_transforms -- confirm there.
p_transform_augment = {
    'translation_range_z': [-16, 16],
    'translation_range_y': [-16, 16],
    'translation_range_x': [-16, 16],
    'rotation_range_z': [-180, 180],
    'rotation_range_y': [-180, 180],
    'rotation_range_x': [-180, 180],
}

# Placeholders for the normalisation statistics. data_prep_function reads these
# module globals at call time; they are overwritten further down in this file
# once the statistics have been estimated from training data.
zmuv_mean, zmuv_std = None, None
30
31
32
# data preparation function
33
def data_prep_function(data, patch_center, luna_annotations, pixel_spacing, luna_origin, p_transform,
                       p_transform_augment, **kwargs):
    """Build one (input patch, target mask) pair for the segmentation network.

    The scan is cut/transformed around ``patch_center`` by
    ``data_transforms.transform_patch3d``, passed through ``hu2normHU`` and
    ``zmuv``, and a mask with ``shape='sphere'`` is generated from the
    transformed annotations at the same shape as the patch.

    Args:
        data: scan volume forwarded to ``transform_patch3d``.
        patch_center: centre of the patch to extract.
        luna_annotations: annotations of the scan.
        pixel_spacing: spacing of ``data``.
        luna_origin: origin of ``data``.
        p_transform: patch geometry parameters.
        p_transform_augment: augmentation parameters, or None (as used for
            validation in this file).
        **kwargs: accepted and ignored, so callers may pass extra fields.

    Returns:
        Tuple ``(x, y)``: the normalised patch and its mask.
    """
    # The second return value (the transformed patch annotation) is not needed here.
    x, _, annotations_tf = data_transforms.transform_patch3d(
        data=data,
        luna_annotations=luna_annotations,
        patch_center=patch_center,
        p_transform=p_transform,
        p_transform_augment=p_transform_augment,
        pixel_spacing=pixel_spacing,
        luna_origin=luna_origin)
    x = data_transforms.hu2normHU(x)
    # zmuv_mean / zmuv_std are module globals looked up on every call, so this
    # picks up whatever values the module holds at that moment (possibly None).
    # NOTE(review): presumably zmuv() leaves x unchanged for None -- confirm in data_transforms.
    x = data_transforms.zmuv(x, zmuv_mean, zmuv_std)
    y = data_transforms.make_3d_mask_from_annotations(img_shape=x.shape,
                                                      annotations=annotations_tf,
                                                      shape='sphere')
    return x, y
46
47
48
# Same preparation for both splits; only the training variant applies augmentation.
data_prep_function_train = partial(data_prep_function,
                                   p_transform=p_transform,
                                   p_transform_augment=p_transform_augment)
data_prep_function_valid = partial(data_prep_function,
                                   p_transform=p_transform,
                                   p_transform_augment=None)
50
51
# data iterators
52
batch_size = 4
nbatches_chunk = 8
# The training generator is asked for whole chunks (batch_size * nbatches_chunk samples).
chunk_size = batch_size * nbatches_chunk

# Pickled dict with the patient ids of the two splits under 'train' and 'valid'.
train_valid_ids = utils.load_pkl(pathfinder.LUNA_VALIDATION_SPLIT_PATH)
train_pids = train_valid_ids['train']
valid_pids = train_valid_ids['valid']

# Training generator: yields chunk_size samples at a time, uses the seeded rng
# and the augmenting data-prep function.
train_data_iterator = data_iterators.PatchPositiveLunaDataGenerator(
    data_path=pathfinder.LUNA_DATA_PATH,
    batch_size=chunk_size,
    transform_params=p_transform,
    data_prep_fun=data_prep_function_train,
    rng=rng,
    patient_ids=train_pids,
    full_batch=True, random=True, infinite=True)

# Validation generator: same patch geometry, non-augmenting data-prep function.
valid_data_iterator = data_iterators.ValidPatchPositiveLunaDataGenerator(
    data_path=pathfinder.LUNA_DATA_PATH,
    transform_params=p_transform,
    data_prep_fun=data_prep_function_valid,
    patient_ids=valid_pids)
71
72
print 'estimating ZMUV parameters'
73
x_big = None
74
for i, (x, _, _) in zip(xrange(4), train_data_iterator.generate()):
75
    x_big = x if x_big is None else np.concatenate((x_big, x), axis=0)
76
zmuv_mean = x_big.mean()
77
zmuv_std = x_big.std()
78
# assert abs(zmuv_mean - 0.35) < 0.01
79
# assert abs(zmuv_std - 0.30) < 0.01
80
print 'mean:', zmuv_mean
81
print 'std:', zmuv_std
82
83
# Number of chunks that make up one pass over the training samples.
# NOTE(review): relies on Python 2 integer division ('/' on ints) -- assumes nsamples is an int.
nchunks_per_epoch = train_data_iterator.nsamples / chunk_size
# Total training length: 30 epochs' worth of chunks.
max_nchunks = nchunks_per_epoch * 30

# Validate once per epoch, save twice per epoch (both counted in chunks).
validate_every = int(1. * nchunks_per_epoch)
save_every = int(0.5 * nchunks_per_epoch)

# Learning rate keyed by a chunk count expressed as a fraction of max_nchunks.
# NOTE(review): presumably the training script switches to the mapped rate once
# that chunk index is reached -- confirm in the consumer.
learning_rate_schedule = {
    0: 1e-5,
    int(max_nchunks * 0.4): 5e-6,
    int(max_nchunks * 0.5): 3e-6,
    int(max_nchunks * 0.6): 2e-6,
    int(max_nchunks * 0.85): 1e-6,
    int(max_nchunks * 0.95): 5e-7,
}
97
98
# model
99
# Default 3D convolution for this model: 3x3x3 filters, no padding ('valid'),
# orthogonal weights with the 'relu' gain, zero biases and a linear output --
# the non-linearity is added as a separate layer in build_model.
conv3d = partial(
    dnn.Conv3DDNNLayer,
    filter_size=3,
    pad='valid',
    W=nn.init.Orthogonal('relu'),
    b=nn.init.Constant(0.0),
    nonlinearity=nn.nonlinearities.identity)

# Default 3D max-pooling with a pool size of 2.
max_pool3d = partial(dnn.MaxPool3DDNNLayer, pool_size=2)
108
109
110
def build_model():
    """Build the 3D segmentation network for this config.

    Layout: three conv+PReLU layers, one max-pool, four conv+PReLU layers, an
    upscale by 2, a centre-cropped concatenation with the last pre-pool
    activation, three more conv+PReLU layers and a final 1x1x1 sigmoid
    convolution with a single output channel.

    All 3x3x3 convolutions use pad='valid', so the output is spatially smaller
    than the input; build_objective crops the target to match.
    NOTE(review): with a 64^3 input and stride-1 convs this works out to
    64 -> 58 -> 29 -> 21 -> 42 -> 36 per axis -- confirm against the compiled graph.

    Returns:
        A ``Model`` namedtuple with fields ``l_in``, ``l_out`` and ``l_target``.
    """
    # Input and target share the same declared shape: (batch, 1 channel) + patch size.
    patch_shape = (None, 1,) + p_transform['patch_size']
    l_in = nn.layers.InputLayer(patch_shape)
    l_target = nn.layers.InputLayer(patch_shape)

    def conv_prelu(incoming, num_filters):
        # One conv3d (linear output) followed by a parametric rectifier.
        return nn.layers.ParametricRectifierLayer(conv3d(incoming, num_filters))

    net = {}
    base_n_filters = 64

    # Contracting part at full resolution.
    net['contr_1_1'] = conv_prelu(l_in, base_n_filters)
    net['contr_1_2'] = conv_prelu(net['contr_1_1'], base_n_filters)
    net['contr_1_3'] = conv_prelu(net['contr_1_2'], base_n_filters)
    net['pool1'] = max_pool3d(net['contr_1_3'])

    # Encoder at half resolution.
    net['encode_1'] = conv_prelu(net['pool1'], base_n_filters)
    net['encode_2'] = conv_prelu(net['encode_1'], base_n_filters)
    net['encode_3'] = conv_prelu(net['encode_2'], base_n_filters)
    net['encode_4'] = conv_prelu(net['encode_3'], base_n_filters)

    net['upscale1'] = nn.layers.Upscale3DLayer(net['encode_4'], 2)

    # Skip connection: the three spatial axes are centre-cropped to a common
    # size before concatenation; batch and channel axes are not cropped.
    net['concat1'] = nn.layers.ConcatLayer([net['upscale1'], net['contr_1_3']],
                                           cropping=(None, None, "center", "center", "center"))

    # Expanding part.
    net['expand_1_1'] = conv_prelu(net['concat1'], 2 * base_n_filters)
    net['expand_1_2'] = conv_prelu(net['expand_1_1'], 2 * base_n_filters)
    net['expand_1_3'] = conv_prelu(net['expand_1_2'], base_n_filters)

    # Per-voxel output in (0, 1) from a 1x1x1 convolution with a sigmoid.
    l_out = dnn.Conv3DDNNLayer(net['expand_1_3'], num_filters=1,
                               filter_size=1,
                               nonlinearity=nn.nonlinearities.sigmoid)

    return namedtuple('Model', ['l_in', 'l_out', 'l_target'])(l_in, l_out, l_target)
148
149
150
def build_objective(model, deterministic=False, epsilon=1e-12):
    """Return the symbolic loss: the negated soft Dice coefficient.

    The loss is ``-(2 * sum(t * p) + epsilon) / (sum(t) + sum(p) + epsilon)``,
    summed over every element of the batch, where ``p`` is the network output
    and ``t`` the target clipped to ``[1e-6, 1]`` and centre-cropped to the
    spatial size of ``p``.

    Args:
        model: object with ``l_out`` and ``l_target`` layers (see build_model).
        deterministic: forwarded to ``lasagne.layers.get_output`` for the
            prediction, so stochastic layers (if any are added to the network)
            are switched off when this is True.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        A scalar Theano expression.
    """
    # Forward the flag; the original accepted it but never used it.
    network_predictions = nn.layers.get_output(model.l_out, deterministic=deterministic)
    target_values = nn.layers.get_output(model.l_target)
    # Lower-bound the target at 1e-6 so it is never exactly zero.
    target_values = T.clip(target_values, 1e-6, 1.)
    # Crop both tensors to a common centre region along the three spatial axes
    # (the network output is smaller than the declared target shape).
    network_predictions, target_values = nn.layers.merge.autocrop(
        [network_predictions, target_values],
        [None, None, 'center', 'center', 'center'])
    y_true_f = target_values
    y_pred_f = network_predictions

    intersection = T.sum(y_true_f * y_pred_f)
    return -1. * (2 * intersection + epsilon) / (T.sum(y_true_f) + T.sum(y_pred_f) + epsilon)
161
162
163
def build_updates(train_loss, model, learning_rate):
    """Return Adam update rules for the parameters of the network.

    Args:
        train_loss: scalar loss expression to minimise.
        model: object whose ``l_out`` layer is the top of the network.
        learning_rate: learning rate passed to ``lasagne.updates.adam``.

    Returns:
        The updates produced by ``lasagne.updates.adam``.
    """
    params = nn.layers.get_all_params(model.l_out)
    return nn.updates.adam(train_loss, params, learning_rate)