a b/test/test_model.py
1
from unittest import TestCase
2
3
from fetal_net.model import unet_model_3d
4
5
6
class TestModel(TestCase):
    """Tests for the 3D U-Net builder ``fetal_net.model.unet_model_3d``."""

    def test_batch_normalization(self):
        """Each non-transposed ``conv3d`` layer must have a paired ``batch_normalization`` layer.

        Pairing is done by layer name: a layer named ``conv3d<suffix>`` is expected to
        have a sibling named ``batch_normalization<suffix>`` (presumably relying on the
        framework's auto-generated layer-name counters -- TODO confirm they stay in step).
        """
        model = unet_model_3d(input_shape=(1, 16, 16, 16), depth=2, deconvolution=True, metrics=[], n_labels=1,
                              batch_normalization=True)

        layer_names = [layer.name for layer in model.layers]
        known_names = set(layer_names)  # O(1) membership checks

        checked = 0
        # NOTE(review): slices off the last three layers; the original intent was to exclude
        # the final output convolution (and presumably the layers after it) -- confirm against
        # unet_model_3d.
        for name in layer_names[:-3]:
            # Transposed (up-sampling) convolutions are not expected to be batch-normalized.
            if 'conv3d' in name and 'transpose' not in name:
                # subTest reports every missing batch-norm layer, not just the first one.
                with self.subTest(layer=name):
                    self.assertIn(name.replace('conv3d', 'batch_normalization'), known_names)
                checked += 1

        # Guard against a vacuous pass: if no convolution layers were found (e.g. layer
        # naming changed), the loop above would assert nothing.
        self.assertGreater(checked, 0, "no conv3d layers were found to check for batch normalization")