# tests/test_model.py
from __future__ import division, print_function

import unittest

from keras.layers import Input
from keras import backend as K

from rvseg.models import convunet
from rvseg.models import unet

class TestModel(unittest.TestCase):
    def test_downsampling(self):
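        # x is the pooled output passed to the next block, y the pre-pool
        # activation kept for the skip connection: with 'valid' padding the
        # convolutions trim 28 -> 24 (y) and 2x2 max pooling halves that
        # to 12 (x).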
        inputs = Input(shape=(28, 28, 1))
        filters = 16
        padding = 'valid'
        x, y = convunet.downsampling_block(inputs, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 12, 12, filters))
        self.assertTupleEqual(K.int_shape(y), (None, 24, 24, filters))
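
        # 'same' padding preserves the 28x28 spatial size through the
        # convolutions, so only the 2x2 max pool shrinks it, to 14x14.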
        padding = 'same'
        x, y = convunet.downsampling_block(inputs, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 14, 14, filters))
        self.assertTupleEqual(K.int_shape(y), (None, 28, 28, filters))

    def test_downsampling_error(self):
        # downsampling should fail on odd-integer dimension images
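        # (a 2x2 max pool cannot evenly halve a 29-pixel dimension)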
        inputs = Input(shape=(29, 29, 1))
        filters = 16
        with self.assertRaises(AssertionError):
            convunet.downsampling_block(inputs, filters, padding='valid')
        with self.assertRaises(AssertionError):
            convunet.downsampling_block(inputs, filters, padding='same')

    def test_upsampling(self):
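        # The upsampling block doubles the spatial size with a 2x2 up-conv,
        # crops the skip tensor to match, concatenates the two, and then
        # convolves; the expected shapes below follow from that.
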
        # concatenation without cropping
        filters = 16
        inputs = Input(shape=(14, 14, 2*filters))
        skip = Input(shape=(28, 28, filters))
        padding = 'valid'
        x = convunet.upsampling_block(inputs, skip, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 24, 24, filters))

        # ((4,4), (4,4)) cropping
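        # 10 -> 20 after the up-conv, so the 28x28 skip is cropped by 4 on
        # every side to 20x20 before concatenation; 'valid' convolutions
        # then give 16x16.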
        filters = 15
        inputs = Input(shape=(10, 10, 2*filters))
        skip = Input(shape=(28, 28, filters))
        padding = 'valid'
        x = convunet.upsampling_block(inputs, skip, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 16, 16, filters))

        # odd-integer input size
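        # 11 -> 22 after the up-conv; the 28x28 skip is cropped by 3 per
        # side, and 'valid' convolutions give 18x18.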
        filters = 4
        inputs = Input(shape=(11, 11, 2*filters))
        skip = Input(shape=(28, 28, filters))
        padding = 'valid'
        x = convunet.upsampling_block(inputs, skip, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 18, 18, filters))

        # test odd-integer cropping
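        # 27 -> 22 needs an asymmetric crop (5 pixels total per axis).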
        filters = 5
        inputs = Input(shape=(11, 11, 2*filters))
        skip = Input(shape=(27, 27, filters))
        padding = 'valid'
        x = convunet.upsampling_block(inputs, skip, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 18, 18, filters))

        # test same padding
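        # with 'same' padding the convolutions preserve the 22x22 size
        # obtained after the up-conv and cropping.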
        filters = 5
        inputs = Input(shape=(11, 11, 2*filters))
        skip = Input(shape=(27, 27, filters))
        padding = 'same'
        x = convunet.upsampling_block(inputs, skip, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 22, 22, filters))

    def test_upsampling_error(self):
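        # upsampling should fail when the skip tensor is smaller than the
        # upsampled tensor (22x22 here), since cropping would be negative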
        filters = 2
        inputs = Input(shape=(11, 11, 2*filters))
        padding = 'valid'
        with self.assertRaises(AssertionError):
            skip = Input(shape=(21, 22, filters))
            x = convunet.upsampling_block(inputs, skip, filters, padding)
        with self.assertRaises(AssertionError):
            skip = Input(shape=(22, 21, filters))
            x = convunet.upsampling_block(inputs, skip, filters, padding)

    def test_unet(self):
        # classic u-net architecture from
        #   "U-Net: Convolutional Networks for Biomedical Image Segmentation"
        #   O. Ronneberger, P. Fischer, T. Brox (2015)
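        # With valid padding, a 572x572 input maps to a 388x388 output
        # segmentation map, exactly as in Fig. 1 of the paper.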
        height, width, channels = 572, 572, 1
        features = 64
        depth = 4
        classes = 2
        temperature = 1.0
        padding = 'valid'
        m = unet(height, width, channels, classes, features, depth,
                 temperature, padding)
        self.assertEqual(len(m.layers), 56)

        # input/output dimensions
        self.assertTupleEqual(K.int_shape(m.input), (None, 572, 572, 1))
        self.assertTupleEqual(K.int_shape(m.output), (None, 388, 388, 2))

        # layers
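        # expected output shape of every layer, in model order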
        layer_output_dims = [
            (None, 572, 572, 1), # input
            (None, 570, 570, 64),
            (None, 570, 570, 64),
            (None, 568, 568, 64), # skip 1
            (None, 568, 568, 64),
            (None, 284, 284, 64), # max pool 2x2
            (None, 282, 282, 128),
            (None, 282, 282, 128),
            (None, 280, 280, 128), # skip 2
            (None, 280, 280, 128),
            (None, 140, 140, 128), # max pool 2x2
            (None, 138, 138, 256),
            (None, 138, 138, 256),
            (None, 136, 136, 256), # skip 3
            (None, 136, 136, 256),
            (None, 68, 68, 256), # max pool 2x2
            (None, 66, 66, 512),
            (None, 66, 66, 512),
            (None, 64, 64, 512), # skip 4
            (None, 64, 64, 512),
            (None, 32, 32, 512), # max pool 2x2
            (None, 30, 30, 1024),
            (None, 30, 30, 1024),
            (None, 28, 28, 1024),
            (None, 28, 28, 1024),
            (None, 56, 56, 512), # up-conv 2x2
            (None, 56, 56, 512), # cropping of skip 4
            (None, 56, 56, 1024), # concat
            (None, 54, 54, 512),
            (None, 54, 54, 512),
            (None, 52, 52, 512),
            (None, 52, 52, 512),
            (None, 104, 104, 256), # up-conv 2x2
            (None, 104, 104, 256), # cropping of skip 3
            (None, 104, 104, 512), # concat
            (None, 102, 102, 256),
            (None, 102, 102, 256),
            (None, 100, 100, 256),
            (None, 100, 100, 256),
            (None, 200, 200, 128), # up-conv 2x2
            (None, 200, 200, 128), # cropping of skip 2
            (None, 200, 200, 256), # concat
            (None, 198, 198, 128),
            (None, 198, 198, 128),
            (None, 196, 196, 128),
            (None, 196, 196, 128),
            (None, 392, 392, 64), # up-conv 2x2
            (None, 392, 392, 64), # cropping of skip 1
            (None, 392, 392, 128), # concat
            (None, 390, 390, 64),
            (None, 390, 390, 64),
            (None, 388, 388, 64),
            (None, 388, 388, 64),
            (None, 388, 388, 2), # output segmentation map
            (None, 388, 388, 2),
            (None, 388, 388, 2),
        ]
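        # the length check above guards this zip, which would otherwise
        # silently truncate if layers were missing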
        for layer, shape in zip(m.layers, layer_output_dims):
            self.assertTupleEqual(layer.output_shape, shape)

    def check_layer_dims(self, model):
        # Whether we enable only batch normalization or only dropout, the
        # network should have the same sequence of layer output shapes; the
        # comments below list the layer type for each case
        # (batchnorm-only | dropout-only).
        layer_output_dims = [
            (None, 10, 10, 1), # input
            (None, 10, 10, 4), # conv2D
            (None, 10, 10, 4), # batchnorm | reLU
            (None, 10, 10, 4), # reLU      | dropout
            (None, 10, 10, 4), # conv2D
            (None, 10, 10, 4), # batchnorm | reLU
            (None, 10, 10, 4), # reLU      | dropout
            (None, 5, 5, 4),   # max pool 2x2
            (None, 5, 5, 8),   # conv2D
            (None, 5, 5, 8),   # batchnorm | reLU
            (None, 5, 5, 8),   # reLU      | dropout
            (None, 5, 5, 8),   # conv2D
            (None, 5, 5, 8),   # batchnorm | reLU
            (None, 5, 5, 8),   # reLU      | dropout
            (None, 10, 10, 4), # up-conv 2x2
            (None, 10, 10, 8), # concat
            (None, 10, 10, 4), # conv2D
            (None, 10, 10, 4), # batchnorm | reLU
            (None, 10, 10, 4), # reLU      | dropout
            (None, 10, 10, 4), # conv2D
            (None, 10, 10, 4), # batchnorm | reLU
            (None, 10, 10, 4), # reLU      | dropout
            (None, 10, 10, 2), # output segmentation map
            (None, 10, 10, 2), # (temperature)
            (None, 10, 10, 2), # softmax
        ]
        for layer, shape in zip(model.layers, layer_output_dims):
            self.assertTupleEqual(layer.output_shape, shape)

    def test_batchnorm(self):
        # only batch norm, no dropout
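        # with padding='same', spatial size only changes at the 2x2 max
        # pool (10x10 -> 5x5) and the up-conv (5x5 -> 10x10); see the
        # expected shapes in check_layer_dims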
        height, width, channels = 10, 10, 1
        features = 4
        depth = 1
        classes = 2
        temperature = 1.0
        padding = 'same'
        batchnorm = True
        dropout = False
        m = unet(height, width, channels, classes, features, depth,
                 temperature, padding, batchnorm, dropout)
        self.assertEqual(len(m.layers), 25)

        # input/output dimensions
        self.assertTupleEqual(K.int_shape(m.input), (None, 10, 10, 1))
        self.assertTupleEqual(K.int_shape(m.output), (None, 10, 10, 2))

        self.check_layer_dims(m)

    def test_dropout(self):
        # only dropout, no batch norm
        height, width, channels = 10, 10, 1
        features = 4
        depth = 1
        classes = 2
        temperature = 1.0
        padding = 'same'
        batchnorm = False
        dropout = True
        m = unet(height, width, channels, classes, features, depth,
                 temperature, padding, batchnorm, dropout)
        self.assertEqual(len(m.layers), 25)

        # input/output dimensions
        self.assertTupleEqual(K.int_shape(m.input), (None, 10, 10, 1))
        self.assertTupleEqual(K.int_shape(m.output), (None, 10, 10, 2))

        self.check_layer_dims(m)
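

if __name__ == '__main__':
    # allow running this test module directly: python tests/test_model.py
    unittest.main()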