|
a |
|
b/test/test_model.py |
|
|
1 |
from unittest import TestCase |
|
|
2 |
|
|
|
3 |
from fetal_net.model import unet_model_3d |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class TestModel(TestCase):
    """Tests for the model returned by ``unet_model_3d``."""

    def test_batch_normalization(self):
        """Check that each inspected ``conv3d`` layer has a batch-norm counterpart.

        The model is built with ``batch_normalization=True``.  For every
        inspected layer name that contains ``conv3d`` but not ``transpose``,
        the name obtained by replacing ``conv3d`` with
        ``batch_normalization`` must also be present among the layer names.
        """
        network = unet_model_3d(
            input_shape=(1, 16, 16, 16),
            depth=2,
            deconvolution=True,
            metrics=[],
            n_labels=1,
            batch_normalization=True,
        )

        layer_names = [layer.name for layer in network.layers]

        # The final three names are skipped; the original author's note says
        # this is to exclude the last convolution layer.
        plain_convolutions = [
            candidate
            for candidate in layer_names[:-3]
            if 'conv3d' in candidate and 'transpose' not in candidate
        ]

        for conv_name in plain_convolutions:
            expected_name = conv_name.replace('conv3d', 'batch_normalization')
            self.assertIn(expected_name, layer_names)