Switch to unified view

a b/model_definition/default.py
1
import tensorflow as tf
2
import config
3
4
from model_utils import calculate_conv_output_size
5
6
7
n_x = config.IMAGE_PXL_SIZE_X
8
n_y = config.IMAGE_PXL_SIZE_Y
9
n_z = config.SLICES
10
11
# This handles padding in both convolution and pooling layers
12
strides = [[2, 2, 2],
13
           [2, 3, 3],
14
           [1, 1, 1],
15
           [2, 2, 2],
16
           [1, 1, 1],
17
           [1, 1, 1],
18
           [2, 2, 2]]
19
20
filters = [[4, 5, 5],
21
            [3, 4, 4],
22
            [3, 3, 3],
23
            [3, 3, 3],
24
            [3, 3, 3],
25
            [3, 3, 3],
26
            [3, 3, 3]]
27
            
28
padding_types = ['VALID'] * 7
29
30
31
# Default network config used with more slices
32
# and larger convololution stride on first layer
33
default_config = {
34
    'weights': [
35
        # Convolution layers
36
        ('wc1', tf.truncated_normal([4, 5, 5, config.NUM_CHANNELS, 16], stddev=0.01)),
37
        ('wc2', tf.truncated_normal([3, 3, 3, 16, 64], stddev=0.01)),
38
        ('wc3', tf.truncated_normal([3, 3, 3, 64, 64], stddev=0.01)),
39
        ('wc4', tf.truncated_normal([3, 3, 3, 64, 32], stddev=0.01)),
40
        # Fully connected layers
41
        ('wd1', tf.truncated_normal([calculate_conv_output_size(n_x, n_y, n_z, 
42
                                                                strides, 
43
                                                                filters,
44
                                                                padding_types, 
45
                                                                32), 
46
                                    100], stddev=0.01)),
47
        ('wd2', tf.truncated_normal([100, 50], stddev=0.01)),
48
        ('wout', tf.truncated_normal([50, config.N_CLASSES], stddev=0.01))
49
    ],
50
    'biases': (
51
        # Convolution layers
52
        ('bc1', tf.zeros([16])),
53
        ('bc2', tf.constant(1.0, shape=[64])),
54
        ('bc3', tf.zeros([64])),
55
        ('bc4', tf.constant(1.0, shape=[32])),
56
        # Fully connected layers
57
        ('bd1', tf.constant(1.0, shape=[100])),
58
        ('bd2', tf.constant(1.0, shape=[50])),
59
        ('bout', tf.constant(1.0, shape=[config.N_CLASSES]))
60
    ),
61
    'pool_strides': [
62
        [1, 2, 3, 3, 1],
63
        [1, 2, 2, 2, 1],
64
        [],
65
        [1, 2, 2, 2, 1],
66
    ],
67
    'pool_windows': [
68
        [1, 3, 4, 4, 1],
69
        [1, 3, 3, 3, 1],
70
        [],
71
        [1, 3, 3, 3, 1],
72
    ],
73
    'strides': [
74
        [1, 2, 2, 2, 1],
75
        [1, 1, 1, 1, 1],
76
        [1, 1, 1, 1, 1],
77
        [1, 1, 1, 1, 1],
78
    ]
79
}