# model_definition/baseline.py
1
import tensorflow as tf
2
import config
3
4
from model_utils import calculate_conv_output_size
5
6
7
# Input sizes read from the project configuration; they are forwarded below
# to calculate_conv_output_size as its x, y and z arguments.
n_x, n_y, n_z = (
    config.IMAGE_PXL_SIZE_X,
    config.IMAGE_PXL_SIZE_Y,
    config.SLICES,
)
10
11
# Per-layer geometry of the network, listed in forward order. The entries
# alternate convolution, pooling, convolution, pooling, ...: even indices
# repeat the convolution kernel sizes/strides hard-coded in `baseline_config`
# and odd indices repeat its pooling windows/strides. The three lists are
# parallel (same length, same layer order) and are passed together to
# `calculate_conv_output_size` to size the first fully connected layer, so
# they have to be kept in sync with `baseline_config` by hand.
# NOTE(review): the axis order inside each triple is not visible from this
# file -- confirm against calculate_conv_output_size.

# Stride of each layer along its three spatial axes.
strides = [
    [1, 1, 1],  # conv 1
    [2, 4, 4],  # pool 1
    [1, 1, 1],  # conv 2
    [2, 2, 2],  # pool 2
    [1, 1, 1],  # conv 3
    [2, 2, 2],  # pool 3
]

# Kernel (convolution) or window (pooling) size of each layer.
filters = [
    [3, 5, 5],  # conv 1
    [3, 5, 5],  # pool 1
    [3, 3, 3],  # conv 2
    [3, 3, 3],  # pool 2
    [3, 3, 3],  # conv 3
    [3, 3, 3],  # pool 3
]

# One padding mode per layer; every entry is 'VALID'. The length is derived
# from `strides` so the parallel lists cannot drift apart.
padding_types = ['VALID'] * len(strides)
27
28
29
# Initial values and layer hyper-parameters for the baseline network.
#
# 'weights' and 'biases' are ordered lists of (name, initial-value tensor)
# pairs. 'pool_strides', 'pool_windows' and 'strides' each hold one 5-element
# list per convolution layer (three entries, matching wc1..wc3).
# NOTE(review): the 5-element lists look like the
# [batch, d1, d2, d3, channels] layout TensorFlow uses for 3-D ops -- confirm
# in the code that consumes this dict.
baseline_config = {
    'weights': [
        # Convolution layers
        # NOTE(review): the last two shape entries appear to be (input
        # channels, output channels): each layer's last entry equals the next
        # layer's fourth entry and the length of the matching bias below.
        ('wc1', tf.truncated_normal([3, 5, 5, config.NUM_CHANNELS, 16], stddev=0.01)),
        ('wc2', tf.truncated_normal([3, 3, 3, 16, 32], stddev=0.01)),
        ('wc3', tf.truncated_normal([3, 3, 3, 32, 32], stddev=0.01)),
        # Fully connected layers
        # The first dimension of wd1 is computed from the input sizes and the
        # module-level strides/filters/padding_types lists, not hard-coded.
        # NOTE(review): the trailing 32 is presumably wc3's output channel
        # count -- confirm against calculate_conv_output_size.
        ('wd1', tf.truncated_normal(
            [
                calculate_conv_output_size(
                    n_x, n_y, n_z, strides, filters, padding_types, 32),
                100,
            ],
            stddev=0.01)),
        ('wout', tf.truncated_normal([100, config.N_CLASSES], stddev=0.01)),
    ],
    'biases': [
        # Convolution layers
        # bc1 and bc3 start at zero; bc2 starts at one.
        ('bc1', tf.zeros([16])),
        ('bc2', tf.constant(1.0, shape=[32])),
        ('bc3', tf.zeros([32])),
        # Fully connected layers
        ('bd1', tf.constant(1.0, shape=[100])),
        ('bout', tf.constant(1.0, shape=[config.N_CLASSES])),
    ],
    # The middle three values of each entry repeat the odd-indexed rows of
    # the module-level `strides` list.
    'pool_strides': [
        [1, 2, 4, 4, 1],
        [1, 2, 2, 2, 1],
        [1, 2, 2, 2, 1],
    ],
    # The middle three values of each entry repeat the odd-indexed rows of
    # the module-level `filters` list.
    'pool_windows': [
        [1, 3, 5, 5, 1],
        [1, 3, 3, 3, 1],
        [1, 3, 3, 3, 1],
    ],
    # Convolution strides: every layer uses a stride of one throughout.
    'strides': [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ],
}