|
a |
|
b/tests/test_model.py |
|
|
1 |
from __future__ import division, print_function |
|
|
2 |
|
|
|
3 |
import unittest |
|
|
4 |
|
|
|
5 |
from keras.layers import Input |
|
|
6 |
from keras import backend as K |
|
|
7 |
|
|
|
8 |
from rvseg.models import convunet |
|
|
9 |
from rvseg.models import unet |
|
|
10 |
|
|
|
11 |
class TestModel(unittest.TestCase):
    """Static-shape tests for the ``convunet`` blocks and the ``unet`` model.

    Each test builds symbolic Keras inputs or a full model and compares the
    resulting static shapes with hand-computed values; no data is fed
    through the network.
    """

    def _assert_layer_shapes(self, model, expected_shapes):
        """Assert ``model.layers`` matches ``expected_shapes`` one-to-one.

        The layer count is compared first because ``zip`` stops at the
        shorter sequence; without that check a model with missing or
        surplus layers would pass the per-layer comparison.
        """
        self.assertEqual(len(model.layers), len(expected_shapes))
        for layer, shape in zip(model.layers, expected_shapes):
            self.assertTupleEqual(layer.output_shape, shape)

    def _check_small_unet(self, batchnorm, dropout):
        """Build a depth-1, 'same'-padded 10x10 u-net and verify its shapes.

        ``batchnorm`` and ``dropout`` are forwarded positionally to
        ``unet`` as its last two arguments.
        """
        height, width, channels = 10, 10, 1
        features = 4
        depth = 1
        classes = 2
        temperature = 1.0
        padding = 'same'
        m = unet(height, width, channels, classes, features, depth,
                 temperature, padding, batchnorm, dropout)
        self.assertEqual(len(m.layers), 25)

        # input/output dimensions
        self.assertTupleEqual(K.int_shape(m.input), (None, 10, 10, 1))
        self.assertTupleEqual(K.int_shape(m.output), (None, 10, 10, 2))

        self.check_layer_dims(m)

    def test_downsampling(self):
        """Check both outputs of ``downsampling_block`` on a 28x28 input."""
        # NOTE(review): presumably x is the pooled tensor and y the skip
        # connection -- confirm against convunet.downsampling_block.
        inputs = Input(shape=(28, 28, 1))
        filters = 16

        padding = 'valid'
        x, y = convunet.downsampling_block(inputs, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 12, 12, filters))
        self.assertTupleEqual(K.int_shape(y), (None, 24, 24, filters))

        padding = 'same'
        x, y = convunet.downsampling_block(inputs, filters, padding)
        self.assertTupleEqual(K.int_shape(x), (None, 14, 14, filters))
        self.assertTupleEqual(K.int_shape(y), (None, 28, 28, filters))

    def test_downsampling_error(self):
        """Downsampling should fail on odd-integer dimension images."""
        inputs = Input(shape=(29, 29, 1))
        filters = 16
        with self.assertRaises(AssertionError):
            convunet.downsampling_block(inputs, filters, padding='valid')
        with self.assertRaises(AssertionError):
            convunet.downsampling_block(inputs, filters, padding='same')

    def test_upsampling(self):
        """Check the output shape of ``upsampling_block`` for several cases."""
        # (filters, input size, skip size, padding, expected output size)
        cases = [
            (16, 14, 28, 'valid', 24),  # concatenation without cropping
            (15, 10, 28, 'valid', 16),  # ((4,4), (4,4)) cropping
            (4, 11, 28, 'valid', 18),   # odd-integer input size
            (5, 11, 27, 'valid', 18),   # test odd-integer cropping
            (5, 11, 27, 'same', 22),    # test same padding
        ]
        for filters, input_size, skip_size, padding, expected in cases:
            inputs = Input(shape=(input_size, input_size, 2*filters))
            skip = Input(shape=(skip_size, skip_size, filters))
            x = convunet.upsampling_block(inputs, skip, filters, padding)
            self.assertTupleEqual(K.int_shape(x),
                                  (None, expected, expected, filters))

    def test_upsampling_error(self):
        """``upsampling_block`` must reject these mismatched skip shapes."""
        # NOTE(review): presumably these fail because the skip tensor is
        # smaller than the upsampled input along one axis -- confirm against
        # convunet.upsampling_block.
        filters = 2
        inputs = Input(shape=(11, 11, 2*filters))
        padding = 'valid'
        for skip_shape in [(21, 22, filters), (22, 21, filters)]:
            # Build the skip tensor outside the assertRaises block so that
            # only the call under test is allowed to raise.
            skip = Input(shape=skip_shape)
            with self.assertRaises(AssertionError):
                convunet.upsampling_block(inputs, skip, filters, padding)

    def test_unet(self):
        """Check every layer shape of the classic 'valid'-padded u-net."""
        # classic u-net architecture from
        # "U-Net: Convolutional Networks for Biomedical Image Segmentation"
        # O. Ronneberger, P. Fischer, T. Brox (2015)
        height, width, channels = 572, 572, 1
        features = 64
        depth = 4
        classes = 2
        temperature = 1.0
        padding = 'valid'
        m = unet(height, width, channels, classes, features, depth,
                 temperature, padding)
        self.assertEqual(len(m.layers), 56)

        # input/output dimensions
        self.assertTupleEqual(K.int_shape(m.input), (None, 572, 572, 1))
        self.assertTupleEqual(K.int_shape(m.output), (None, 388, 388, 2))

        # layers
        layer_output_dims = [
            (None, 572, 572, 1),     # input
            (None, 570, 570, 64),
            (None, 570, 570, 64),
            (None, 568, 568, 64),    # skip 1
            (None, 568, 568, 64),
            (None, 284, 284, 64),    # max pool 2x2
            (None, 282, 282, 128),
            (None, 282, 282, 128),
            (None, 280, 280, 128),   # skip 2
            (None, 280, 280, 128),
            (None, 140, 140, 128),   # max pool 2x2
            (None, 138, 138, 256),
            (None, 138, 138, 256),
            (None, 136, 136, 256),   # skip 3
            (None, 136, 136, 256),
            (None, 68, 68, 256),     # max pool 2x2
            (None, 66, 66, 512),
            (None, 66, 66, 512),
            (None, 64, 64, 512),     # skip 4
            (None, 64, 64, 512),
            (None, 32, 32, 512),     # max pool 2x2
            (None, 30, 30, 1024),
            (None, 30, 30, 1024),
            (None, 28, 28, 1024),
            (None, 28, 28, 1024),
            (None, 56, 56, 512),     # up-conv 2x2
            (None, 56, 56, 512),     # cropping of skip 4
            (None, 56, 56, 1024),    # concat
            (None, 54, 54, 512),
            (None, 54, 54, 512),
            (None, 52, 52, 512),
            (None, 52, 52, 512),
            (None, 104, 104, 256),   # up-conv 2x2
            (None, 104, 104, 256),   # cropping of skip 3
            (None, 104, 104, 512),   # concat
            (None, 102, 102, 256),
            (None, 102, 102, 256),
            (None, 100, 100, 256),
            (None, 100, 100, 256),
            (None, 200, 200, 128),   # up-conv 2x2
            (None, 200, 200, 128),   # cropping of skip 2
            (None, 200, 200, 256),   # concat
            (None, 198, 198, 128),
            (None, 198, 198, 128),
            (None, 196, 196, 128),
            (None, 196, 196, 128),
            (None, 392, 392, 64),    # up-conv 2x2
            (None, 392, 392, 64),    # cropping of skip 1
            (None, 392, 392, 128),   # concat
            (None, 390, 390, 64),
            (None, 390, 390, 64),
            (None, 388, 388, 64),
            (None, 388, 388, 64),
            (None, 388, 388, 2),     # output segmentation map
            (None, 388, 388, 2),
            (None, 388, 388, 2),
        ]
        self._assert_layer_shapes(m, layer_output_dims)

    def check_layer_dims(self, model):
        """Assert ``model`` has the layer shapes of the small test u-net.

        Shared by ``test_batchnorm`` and ``test_dropout``; not collected as
        a test itself because its name lacks the ``test`` prefix.
        """
        # if we include only one of batch normalization or dropout,
        # then the shape of the network should be the same.
        layer_output_dims = [
            (None, 10, 10, 1),  # input
            (None, 10, 10, 4),  # conv2D
            (None, 10, 10, 4),  # batchnorm | reLU
            (None, 10, 10, 4),  # reLU      | dropout
            (None, 10, 10, 4),  # conv2D
            (None, 10, 10, 4),  # batchnorm | reLU
            (None, 10, 10, 4),  # reLU      | dropout
            (None, 5, 5, 4),    # max pool 2x2
            (None, 5, 5, 8),    # conv2D
            (None, 5, 5, 8),    # batchnorm | reLU
            (None, 5, 5, 8),    # reLU      | dropout
            (None, 5, 5, 8),    # conv2D
            (None, 5, 5, 8),    # batchnorm | reLU
            (None, 5, 5, 8),    # reLU      | dropout
            (None, 10, 10, 4),  # up-conv 2x2
            (None, 10, 10, 8),  # concat
            (None, 10, 10, 4),  # conv2D
            (None, 10, 10, 4),  # batchnorm | reLU
            (None, 10, 10, 4),  # reLU      | dropout
            (None, 10, 10, 4),  # conv2D
            (None, 10, 10, 4),  # batchnorm | reLU
            (None, 10, 10, 4),  # reLU      | dropout
            (None, 10, 10, 2),  # output segmentation map
            (None, 10, 10, 2),  # (temperature)
            (None, 10, 10, 2),  # softmax
        ]
        self._assert_layer_shapes(model, layer_output_dims)

    def test_batchnorm(self):
        """Only batch norm, no dropout."""
        self._check_small_unet(batchnorm=True, dropout=False)

    def test_dropout(self):
        """Only dropout, no batch norm."""
        self._check_small_unet(batchnorm=False, dropout=True)