# model_definition/additional_layers.py
1
import tensorflow as tf
2
import config
3
4
from model_utils import calculate_conv_output_size
5
6
7
# Input volume size handed to calculate_conv_output_size: the image size
# along x and y (config.IMAGE_PXL_SIZE_X / config.IMAGE_PXL_SIZE_Y) and the
# number of slices (config.SLICES).
n_x, n_y, n_z = (config.IMAGE_PXL_SIZE_X,
                 config.IMAGE_PXL_SIZE_Y,
                 config.SLICES)
10
11
# Per-layer geometry passed to calculate_conv_output_size to size the input of
# the first dense layer (wd1 in additional_layers_config). Convolution and
# pooling layers share these lists, one row per layer, so padding for both
# kinds of layer is described here. Judging by the kernel / pooling shapes in
# additional_layers_config, the rows are, in order:
#   conv1, pool1, conv2, pool2, conv3, conv4, pool4  (conv3 has no pooling)
# Each row is presumably [depth, height, width]: the middle three entries of
# the 5-element TF vectors used in additional_layers_config.
strides = [[1, 1, 1],  # conv1
           [2, 4, 4],  # pool1
           [1, 1, 1],  # conv2
           [2, 2, 2],  # pool2
           [1, 1, 1],  # conv3
           [1, 1, 1],  # conv4
           [2, 2, 2]]  # pool4

filters = [[3, 5, 5],  # conv1 kernel
           [3, 5, 5],  # pool1 window
           [3, 3, 3],  # conv2 kernel
           [3, 3, 3],  # pool2 window
           [3, 3, 3],  # conv3 kernel
           [3, 3, 3],  # conv4 kernel
           [3, 3, 3]]  # pool4 window

# One padding mode per row. The length is derived from `strides` rather than
# hard-coded, so adding or removing a layer cannot leave the lists out of step.
padding_types = ['VALID'] * len(strides)

# calculate_conv_output_size receives these lists side by side; a length
# mismatch would silently mis-size wd1, so fail loudly at import instead.
if len(filters) != len(strides):
    raise ValueError(
        'strides and filters must describe the same layers: got %d and %d rows'
        % (len(strides), len(filters)))
29
30
31
# Standard deviation of the truncated-normal initializer used for every weight.
_WEIGHT_INIT_STDDEV = 0.01

# Output channels of the four convolutions (wc1..wc4) and units of the two
# hidden dense layers (wd1, wd2). Each count sizes one layer's weights, its
# bias and the next layer's input, so it is written once here.
_CONV1_OUT, _CONV2_OUT, _CONV3_OUT, _CONV4_OUT = 16, 64, 64, 32
_DENSE1_UNITS, _DENSE2_UNITS = 100, 50


def _init_weight(shape):
    """Return a truncated-normal initial value of ``shape`` for a weight tensor."""
    return tf.truncated_normal(shape, stddev=_WEIGHT_INIT_STDDEV)


def _as_5d(dims):
    """Return ``[1] + dims + [1]``.

    Pads a 3-element row of ``strides``/``filters`` to the 5-element form
    used for pooling windows and strides below, with the outer entries set to 1.
    """
    return [1] + list(dims) + [1]


# Layer definitions. Kernel, pooling-window and stride sizes come from rows of
# the module-level `strides` / `filters` lists. calculate_conv_output_size uses
# the same lists to size wd1, so the two descriptions of the network cannot
# drift apart. Row indices (identical values to the former hard-coded
# literals): convolutions -> rows 0, 2, 4, 5; pooling -> rows 1, 3, 6.
additional_layers_config = {
    'weights': [
        # Convolution layers: presumably [depth, height, width, in_channels,
        # out_channels] (tf.nn.conv3d filter layout). Each layer's
        # in_channels is the previous layer's out_channels.
        ('wc1', _init_weight(filters[0] + [config.NUM_CHANNELS, _CONV1_OUT])),
        ('wc2', _init_weight(filters[2] + [_CONV1_OUT, _CONV2_OUT])),
        ('wc3', _init_weight(filters[4] + [_CONV2_OUT, _CONV3_OUT])),
        ('wc4', _init_weight(filters[5] + [_CONV3_OUT, _CONV4_OUT])),
        # Fully connected layers. The last argument to
        # calculate_conv_output_size is presumably the channel depth of the
        # final feature map, i.e. wc4's output channels (TODO confirm in
        # model_utils).
        ('wd1', _init_weight([calculate_conv_output_size(n_x, n_y, n_z,
                                                         strides,
                                                         filters,
                                                         padding_types,
                                                         _CONV4_OUT),
                              _DENSE1_UNITS])),
        ('wd2', _init_weight([_DENSE1_UNITS, _DENSE2_UNITS])),
        ('wout', _init_weight([_DENSE2_UNITS, config.N_CLASSES])),
    ],
    # NOTE(review): a tuple, whereas 'weights' is a list. Left as-is in case
    # callers depend on the type.
    'biases': (
        # Convolution layers
        ('bc1', tf.zeros([_CONV1_OUT])),
        ('bc2', tf.constant(1.0, shape=[_CONV2_OUT])),
        ('bc3', tf.zeros([_CONV3_OUT])),
        ('bc4', tf.constant(1.0, shape=[_CONV4_OUT])),
        # Fully connected layers
        ('bd1', tf.constant(1.0, shape=[_DENSE1_UNITS])),
        ('bd2', tf.constant(1.0, shape=[_DENSE2_UNITS])),
        ('bout', tf.constant(1.0, shape=[config.N_CLASSES])),
    ),
    # One entry per convolution, in order. [] presumably means "no pooling
    # after this convolution" (the consuming code is not in this file).
    'pool_strides': [
        _as_5d(strides[1]),
        _as_5d(strides[3]),
        [],
        _as_5d(strides[6]),
    ],
    'pool_windows': [
        _as_5d(filters[1]),
        _as_5d(filters[3]),
        [],
        _as_5d(filters[6]),
    ],
    # Convolution strides, one per convolution (currently all 1).
    'strides': [
        _as_5d(strides[0]),
        _as_5d(strides[2]),
        _as_5d(strides[4]),
        _as_5d(strides[5]),
    ],
}