Diff of /model.py [000000] .. [7b3b92]

Switch to unified view

a b/model.py
1
import tensorflow.keras.backend as K
2
from tensorflow.keras.models import Model
3
from tensorflow.keras import Input
4
from tensorflow.keras.layers import Conv2D, PReLU, UpSampling2D, concatenate , Reshape, Dense, Permute, MaxPool2D
5
from tensorflow.keras.layers import GlobalAveragePooling2D, Activation, add, GaussianNoise, BatchNormalization, multiply
6
from tensorflow.keras.optimizers import SGD
7
from loss import custom_loss
8
K.set_image_data_format("channels_last")
9
10
11
12
def unet_model(input_shape, modified_unet=True, learning_rate=0.01, start_channel=64,
               number_of_levels=3, inc_rate=2, output_channels=4, saved_model_dir=None):
    """
    Builds a UNet model.

    Parameters
    ----------
    input_shape : tuple
        Shape of the input data (height, width, channel).
    modified_unet : bool
        Whether to use the modified UNet or the original UNet.
    learning_rate : float
        Learning rate for the SGD optimizer. The default is 0.01.
    start_channel : int
        Number of channels of the first conv. The default is 64.
    number_of_levels : int
        The depth size of the U-structure. The default is 3.
    inc_rate : int
        Rate at which the conv channels will increase. The default is 2.
    output_channels : int
        The number of output layer channels. The default is 4.
    saved_model_dir : str
        If specified, the model weights will be loaded from this path.
        The default is None.

    Returns
    -------
    model : keras.Model
        The compiled Keras model built from the input parameters.
    """
    input_layer = Input(shape=input_shape, name='the_input_layer')

    if modified_unet:
        # Small additive input noise acts as a light regularizer.
        x = GaussianNoise(0.01, name='Gaussian_Noise')(input_layer)
        x = Conv2D(64, 2, padding='same')(x)
        x = level_block_modified(x, start_channel, number_of_levels, inc_rate)
        x = BatchNormalization(axis=-1)(x)
        x = PReLU(shared_axes=[1, 2])(x)
    else:
        x = level_block(input_layer, start_channel, number_of_levels, inc_rate)

    # 1x1 conv maps the features to the requested number of classes.
    x = Conv2D(output_channels, 1, padding='same')(x)
    output_layer = Activation('softmax')(x)

    model = Model(inputs=input_layer, outputs=output_layer)

    if modified_unet:
        print("The modified UNet was built!")
    else:
        print("The original UNet was built!")

    if saved_model_dir:
        model.load_weights(saved_model_dir)
        print("the model weights were successfully loaded!")

    # FIX: tf.keras optimizers deprecated and later removed the `lr` and
    # `decay` keywords; `learning_rate` (with the default zero decay) is the
    # supported spelling and is numerically identical to the old call.
    sgd = SGD(learning_rate=learning_rate, momentum=0.9)
    model.compile(optimizer=sgd, loss=custom_loss)

    return model
def se_block(x, ratio=16):
    """
    Apply a squeeze-and-excitation block (https://arxiv.org/abs/1709.01507).

    Parameters
    ----------
    x : tensor
        Input keras tensor.
    ratio : int
        The channel reduction ratio of the bottleneck. The default is 16.

    Returns
    -------
    tensor
        The input rescaled by the learned per-channel gates.
    """
    channels_first = K.image_data_format() == "channels_first"
    channel_axis = 1 if channels_first else -1
    n_filters = x.shape[channel_axis]

    # Squeeze: global spatial average, reshaped to a 1x1 feature map so the
    # gates broadcast over the spatial dimensions.
    gates = GlobalAveragePooling2D()(x)
    gates = Reshape((1, 1, n_filters))(gates)

    # Excite: bottleneck MLP ending in per-channel sigmoid weights.
    gates = Dense(n_filters // ratio, activation='relu',
                  kernel_initializer='he_normal', use_bias=False)(gates)
    gates = Dense(n_filters, activation='sigmoid',
                  kernel_initializer='he_normal', use_bias=False)(gates)

    # Move channels back to the front when the backend stores them there.
    if channels_first:
        gates = Permute((3, 1, 2))(gates)

    return multiply([x, gates])
def level_block(x, dim, level, inc):
    """
    Recursively build one level of the original UNet.

    Each level applies a double conv, max-pools down, recurses with `inc`
    times the channels, then upsamples and concatenates the skip connection
    before a final double conv. The bottom level is a double conv only.
    """
    if level <= 0:
        return conv_layers(x, dim)

    skip = conv_layers(x, dim)
    down = MaxPool2D(pool_size=(2, 2))(skip)
    deeper = level_block(down, int(inc * dim), level - 1, inc)
    up = UpSampling2D(size=(2, 2))(deeper)
    up = Conv2D(dim, 2, padding='same')(up)
    merged = concatenate([skip, up])
    return conv_layers(merged, dim)
def level_block_modified(x, dim, level, inc):
    """
    Recursively build one level of the modified UNet.

    Like `level_block`, but uses residual blocks instead of plain double
    convs, a strided conv instead of max-pooling, and a squeeze-excitation
    block on the concatenated skip connection.
    """
    if level <= 0:
        # Bottom of the U: a single encoder-style residual block.
        return res_block(x, dim, encoder_path=True)

    skip = res_block(x, dim, encoder_path=True)
    # Strided 2x2 conv performs the downsampling (no pooling layer).
    down = Conv2D(int(inc * dim), 2, strides=2, padding='same')(skip)
    deeper = level_block_modified(down, int(inc * dim), level - 1, inc)

    up = UpSampling2D(size=(2, 2))(deeper)
    up = Conv2D(dim, 2, padding='same')(up)

    merged = concatenate([skip, up])
    merged = se_block(merged, 8)
    return res_block(merged, dim, encoder_path=False)
def conv_layers(x, dim):
    """Apply two 3x3 same-padded convolutions, each followed by ReLU."""
    for _ in range(2):
        x = Conv2D(dim, 3, padding='same')(x)
        x = Activation("relu")(x)
    return x
def res_block(x, dim, encoder_path=True):
    """
    Pre-activation residual block: two BN -> PReLU -> 3x3 Conv stages added
    to a shortcut of the input.

    Parameters
    ----------
    x : tensor
        Input keras tensor.
    dim : int
        Number of conv filters in the residual path.
    encoder_path : bool
        If True, use an identity shortcut (input channels must already equal
        `dim`). If False, project the shortcut with a bias-free 1x1 conv so
        the channel counts match.

    Returns
    -------
    tensor
        The sum of the shortcut and the residual path.
    """
    residual = x
    for _ in range(2):
        residual = BatchNormalization(axis=-1)(residual)
        residual = PReLU(shared_axes=[1, 2])(residual)
        residual = Conv2D(dim, 3, padding='same')(residual)

    if encoder_path:
        shortcut = x
    else:
        # 1x1 projection aligns channels after the decoder concatenation.
        shortcut = Conv2D(dim, 1, padding='same', use_bias=False)(x)

    return add([shortcut, residual])