Diff of /models.py [000000] .. [1b6491]

Switch to unified view

a b/models.py
1
2
# ==============================================================================
3
# Copyright (C) 2020 Vladimir Juras, Ravinder Regatte and Cem M. Deniz
4
#
5
# This file is part of 2019_IWOAI_Challenge
6
#
7
# This program is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Affero General Public License as published
9
# by the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
11
12
# This program is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
# GNU Affero General Public License for more details.
16
17
# You should have received a copy of the GNU Affero General Public License
18
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
19
# ==============================================================================
20
import tensorflow as tf
21
import tf_layers as tflay
22
from functools import partial
23
24
25
def inference_unet4(x, reg_c=0.1, keep_prob=0.5, channels=1, n_class=2, features_root=1, filter_size=3, pool_size=2, summaries=True, resnet=True):
26
    '''
27
    Creates a new convolutional net for the given parametrization.
28
29
    :param x: input tensor, shape [?,nx,ny,nz,channels]
30
    :param keep_prob: dropout probability tensor
31
    :param channels: number of channels in the input image
32
    :param n_class: number of output labels
33
    :param layers: number of layers in the net
34
    :param features_root: number of features in the first layer
35
    :param filter_size: size of the convolution filter
36
    :param pool_size: size of the max pooling operation
37
    :param summaries: Flag if summaries should be created
38
    '''
39
40
    print('Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}'.format(layers=4,
41
                                                                                                           features=features_root,
42
                                                                                                           filter_size=filter_size,
43
                                                                                                           pool_size=pool_size))
44
    if not resnet:
45
        add_res = partial(tflay.add_res, skip=True)
46
    else:
47
        add_res = partial(tflay.add_res, skip=False)
48
49
    # Placeholder for the input image
50
    nx, ny, nz, channels = x.get_shape()[-4:]
51
    x_image = tf.reshape(x, tf.stack([-1,nx,ny,nz,channels]))
52
    shape_u0a = [nx, ny, nz]
53
    shape_u1a = [(n+1)//2 for n in shape_u0a]
54
    shape_u2a = [(n+1)//2 for n in shape_u1a]
55
    shape_u3a = [(n+1)//2 for n in shape_u2a]
56
    
57
    batch_size = tf.shape(x_image)[0]
58
59
    d0a = tflay.relu('relu_d0a', tflay.conv3d('conv_d0a', x_image, features_root, reg_constant=reg_c))
60
    d0b = tflay.relu('relu_d0b', add_res('res_d0b', tflay.conv3d('conv_d0b', d0a, features_root, reg_constant=reg_c), x_image, conv=False)) # 128 * 128 * 48, n
61
62
    d1a = tflay.max_pool('pool_d1a', d0b) # 64 * 64 * 24, n
63
    d1b = tflay.relu('relu_d1b', tflay.conv3d('conv_d1b', d1a, 2**1*features_root, reg_constant=reg_c)) # 64 * 64 * 24, 2n
64
    d1c = tflay.relu('relu_d1c', add_res('res_d1c', tflay.conv3d('conv_d1b-c', d1b, 2**1*features_root, reg_constant=reg_c), d1a)) # 64 * 64 * 24, 2n
65
66
    d2a = tflay.max_pool('pool_d2a', d1c) # 32 * 32 * 12, 2n
67
    d2b = tflay.relu('relu_d2b', tflay.conv3d('conv_d2b', d2a, 2**2*features_root, reg_constant=reg_c)) # 32 * 32 * 12, 4n
68
    d2c = tflay.relu('relu_d2c', add_res('res_d2c', tflay.conv3d('conv_d2b-c', d2b, 2**2*features_root, reg_constant=reg_c), d2a)) # 32 * 32 * 12, 4n
69
70
    d3a = tflay.max_pool('pool_d3a', d2c) # 16 * 16 * 6, 4n
71
    d3b = tflay.relu('relu_d3b', tflay.conv3d('conv_d3b', d3a, 2**3*features_root, reg_constant=reg_c)) # 16 * 16 * 6, 8n
72
    d3c = tflay.relu('relu_d3c', add_res('res_d3c', tflay.conv3d('conv_d3b-c', d3b, 2**3*features_root, reg_constant=reg_c), d3a)) # 16 * 16 * 6, 8n
73
    d3c = tflay.dropout('dropout_d3c', d3c, keep_prob)
74
75
    d4a = tflay.max_pool('pool_d4a', d3c) # 8 * 8 * 3, 8n
76
    d4b = tflay.relu('relu_d4b', tflay.conv3d('conv_d4b', d4a, 2**4*features_root, kernel_size=[3,3,1], reg_constant=reg_c)) # 8 * 8 * 3, 16n
77
    d4c = tflay.relu('relu_d4c', add_res('res_d4c', tflay.conv3d('conv_d4b-c', d4b, 2**4*features_root, reg_constant=reg_c), d4a)) # 8 * 8 * 3, 16n
78
    d4c = tflay.dropout('dropout_d4c', d4c, keep_prob)
79
80
    u3a = tflay.concat('concat_u3a', tflay.relu('relu_u3a', tflay.upconv3d('upconv_u3a', d4c, 2**3*features_root, shape_u3a, reg_constant=reg_c)), d3c) # 16 * 16 * 6, 16n
81
    u3b = tflay.relu('relu_u3b', tflay.conv3d('conv_u3a-b', u3a, 2**3*features_root, reg_constant=reg_c)) # 16 * 16 * 6, 8n
82
    u3c = tflay.relu('relu_u3c', add_res('res_u3c', tflay.conv3d('conv_u3b-c', u3b, 2**3*features_root, reg_constant=reg_c), u3a)) # 16 * 16 * 6, 8n
83
84
    u2a = tflay.concat('concat_u2a', tflay.relu('relu_u2a', tflay.upconv3d('upconv_u2a', u3c, 2**2*features_root, shape_u2a, reg_constant=reg_c)), d2c) # 32 * 32 * 12, 8n
85
    u2b = tflay.relu('relu_u2b', tflay.conv3d('conv_u2a-b', u2a, 2**2*features_root, reg_constant=reg_c)) # 32 * 32 * 12, 4n
86
    u2c = tflay.relu('relu_u2c', add_res('res_u2c', tflay.conv3d('conv_u2b-c', u2b, 2**2*features_root, reg_constant=reg_c), u2a)) # 32 * 32 * 12, 4n
87
88
    u1a = tflay.concat('concat_u1a', tflay.relu('relu_u1a', tflay.upconv3d('upconv_u1a', u2c, 2**1*features_root, shape_u1a, reg_constant=reg_c)), d1c) # 64 * 64 * 24, 4n
89
    u1b = tflay.relu('relu_u1b', tflay.conv3d('conv_u1a-b', u1a, 2**1*features_root, reg_constant=reg_c)) # 64 * 64 * 24, 2n
90
    u1c = tflay.relu('relu_u1c', add_res('res_u1c', tflay.conv3d('conv_u1b-c', u1b, 2**1*features_root, reg_constant=reg_c), u1a)) # 64 * 64 * 24, 2n
91
92
    u0a = tflay.concat('concat_u0a', tflay.relu('relu_u0a', tflay.upconv3d('upconv_u0a', u1c, 2**0*features_root, shape_u0a, reg_constant=reg_c)), d0b) # 128 * 128 * 48, 2n
93
    u0b = tflay.relu('relu_u0b', tflay.conv3d('conv_u0a-b', u0a, 2**0*features_root, reg_constant=reg_c)) # 128 * 128 * 48, n
94
    u0c = tflay.relu('relu_u0c', add_res('res_u0c', tflay.conv3d('conv_u0b-c', u0b, 2**0*features_root, reg_constant=reg_c, padding='VALID'), u0a)) # 128 * 128 * 48, n
95
96
    score = tflay.relu('relu_result', tflay.conv3d('conv_result', u0c, n_class, kernel_size=[1,1,1], reg_constant=reg_c))
97
98
    return score
99
100
def inference_atrous4(x, reg_c=0.1, keep_prob=0.5, channels=1, n_class=2, features_root=1, filter_size=3, pool_size=2, dilation_rates=[2], summaries=True, resnet=True):
101
    '''
102
    Creates a new convolutional net for the given parametrization.
103
104
    :param x: input tensor, shape [?,nx,ny,nz,channels]
105
    :param keep_prob: dropout probability tensor
106
    :param channels: number of channels in the input image
107
    :param n_class: number of output labels
108
    :param layers: number of layers in the net
109
    :param features_root: number of features in the first layer
110
    :param filter_size: size of the convolution filter
111
    :param pool_size: size of the max pooling operation
112
    :param summaries: Flag if summaries should be created
113
    '''
114
115
    print('Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}'.format(layers=5,
116
                                                                                                           features=features_root,
117
                                                                                                           filter_size=filter_size,
118
                                                                                                           pool_size=pool_size))
119
    if not resnet:
120
        add_res = partial(tflay.add_res, skip=True)
121
    else:
122
        add_res = partial(tflay.add_res, skip=False)
123
124
    # Placeholder for the input image
125
    nx, ny, nz, channels = x.get_shape()[-4:]
126
    x_image = tf.reshape(x, tf.stack([-1,nx,ny,nz,channels]))
127
    shape_u0a = [nx, ny, nz]
128
    shape_u1a = [(n+1)//2 for n in shape_u0a]
129
    shape_u2a = [(n+1)//2 for n in shape_u1a]
130
    shape_u3a = [(n+1)//2 for n in shape_u2a]
131
    shape_u4a = [(n+1)//2 for n in shape_u3a]
132
133
    batch_size = tf.shape(x_image)[0]
134
135
    d0a = tflay.relu('relu_d0a', tflay.conv3d('conv_d0a', x_image, features_root, reg_constant=reg_c))
136
    d0b = tflay.relu('relu_d0b', add_res('res_d0b', tflay.conv3d('conv_d0b', d0a, features_root, reg_constant=reg_c), x_image, conv=False)) # 128 * 128 * 48, n
137
138
    d1a = tflay.max_pool('pool_d1a', d0b) # 64 * 64 * 24, n
139
    d1b = tflay.relu('relu_d1b', tflay.conv3d('conv_d1b', d1a, 2**1*features_root, reg_constant=reg_c)) # 64 * 64 * 24, 2n
140
    d1c = tflay.relu('relu_d1c', add_res('res_d1c', tflay.conv3d('conv_d1b-c', d1b, 2**1*features_root, reg_constant=reg_c), d1a)) # 64 * 64 * 24, 2n
141
142
    d2a = tflay.max_pool('pool_d2a', d1c) # 32 * 32 * 12, 2n
143
    d2b = tflay.relu('relu_d2b', tflay.conv3d('conv_d2b', d2a, 2**2*features_root, reg_constant=reg_c)) # 32 * 32 * 12, 4n
144
    d2c = tflay.relu('relu_d2c', add_res('res_d2c', tflay.conv3d('conv_d2b-c', d2b, 2**2*features_root, reg_constant=reg_c), d2a)) # 32 * 32 * 12, 4n
145
146
    d3a = tflay.max_pool('pool_d3a', d2c) # 16 * 16 * 6, 4n
147
    d3b = tflay.relu('relu_d3b', tflay.conv3d('conv_d3b', d3a, 2**3*features_root, reg_constant=reg_c)) # 16 * 16 * 6, 8n
148
    d3c = tflay.relu('relu_d3c', add_res('res_d3c', tflay.conv3d('conv_d3b-c', d3b, 2**3*features_root, reg_constant=reg_c), d3a)) # 16 * 16 * 6, 8n
149
150
    d4a = tflay.max_pool('pool_d4a', d3c) # 8 * 8 * 3, 8n
151
    d4b = tflay.relu('relu_d4b', tflay.conv3d('conv_d4b', d4a, 2**4*features_root, reg_constant=reg_c)) # 8 * 8 * 3, 16n
152
    d4c = tflay.relu('relu_d4c', add_res('res_d4c', tflay.conv3d('conv_d4b-c', d4b, 2**4*features_root, reg_constant=reg_c), d4a)) # 8 * 8 * 3, 16n
153
    d4c = tflay.dropout('dropout_d4c', d4c, keep_prob)
154
155
    bs = [d4c]
156
    for i, dilation_rate in enumerate(dilation_rates):
157
        name = 'b' + str(i) + 'a'
158
        tmp = tflay.relu('relu_'+name, tflay.atrousconv3d('atrsconv_'+name, d4c, 2**4*features_root, dilation_rate=[dilation_rate,dilation_rate,1], reg_constant=reg_c)) # 8 * 8 * 3, 16n
159
        bs.append(tmp)
160
    bna = tflay.multiconcat('concat_bna', bs)
161
    bna = tflay.dropout('dropout_bna', bna, keep_prob)
162
163
    u3a = tflay.concat('concat_u3a', tflay.relu('relu_u3a', tflay.upconv3d('upconv_u3a', bna, 2**3*features_root, shape_u3a, reg_constant=reg_c)), d3c) # 16 * 16 * 6, 16n
164
    u3b = tflay.relu('relu_u3b', tflay.conv3d('conv_u3a-b', u3a, 2**3*features_root, reg_constant=reg_c)) # 16 * 16 * 6, 8n
165
    u3c = tflay.relu('relu_u3c', add_res('res_u3c', tflay.conv3d('conv_u3b-c', u3b, 2**3*features_root, reg_constant=reg_c), u3a)) # 16 * 16 * 6, 8n
166
167
    u2a = tflay.concat('concat_u2a', tflay.relu('relu_u2a', tflay.upconv3d('upconv_u2a', u3c, 2**2*features_root, shape_u2a, reg_constant=reg_c)), d2c) # 32 * 32 * 12, 8n
168
    u2b = tflay.relu('relu_u2b', tflay.conv3d('conv_u2a-b', u2a, 2**2*features_root, reg_constant=reg_c)) # 32 * 32 * 12, 4n
169
    u2c = tflay.relu('relu_u2c', add_res('res_u2c', tflay.conv3d('conv_u2b-c', u2b, 2**2*features_root, reg_constant=reg_c), u2a)) # 32 * 32 * 12, 4n
170
171
    u1a = tflay.concat('concat_u1a', tflay.relu('relu_u1a', tflay.upconv3d('upconv_u1a', u2c, 2**1*features_root, shape_u1a, reg_constant=reg_c)), d1c) # 64 * 64 * 24, 4n
172
    u1b = tflay.relu('relu_u1b', tflay.conv3d('conv_u1a-b', u1a, 2**1*features_root, reg_constant=reg_c)) # 64 * 64 * 24, 2n
173
    u1c = tflay.relu('relu_u1c', add_res('res_u1c', tflay.conv3d('conv_u1b-c', u1b, 2**1*features_root, reg_constant=reg_c), u1a)) # 64 * 64 * 24, 2n
174
175
    u0a = tflay.concat('concat_u0a', tflay.relu('relu_u0a', tflay.upconv3d('upconv_u0a', u1c, 2**0*features_root, shape_u0a, reg_constant=reg_c)), d0b) # 128 * 128 * 48, 2n
176
    u0b = tflay.relu('relu_u0b', tflay.conv3d('conv_u0a-b', u0a, 2**0*features_root, reg_constant=reg_c)) # 128 * 128 * 48, n
177
    u0c = tflay.relu('relu_u0c', add_res('res_u0c', tflay.conv3d('conv_u0b-c', u0b, 2**0*features_root, reg_constant=reg_c, padding='VALID'), u0a)) # 128 * 128 * 48, n
178
179
    score = tflay.relu('relu_result', tflay.conv3d('conv_result', u0c, n_class, kernel_size=[1,1,1], reg_constant=reg_c))
180
181
    return score