Diff of /utils/models.py [000000] .. [06a92b]

Switch to unified view

a b/utils/models.py
1
# Authors:
2
# Akshay Chaudhari and Zhongnan Fang
3
# May 2018
4
# akshaysc@stanford.edu
5
6
from __future__ import print_function, division
7
8
import numpy as np
9
import pickle
10
import json
11
import math
12
import os
13
14
from keras.models import Model, Sequential
15
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, add, Lambda, Dropout, AlphaDropout
16
from keras.layers import BatchNormalization as BN
17
from keras.utils import plot_model
18
from keras import backend as K
19
import tensorflow as tf
20
21
def unet_2d_model(input_size):
22
23
    # input size is a tuple of the size of the image
24
    # assuming channel last
25
    # input_size = (dim1, dim2, dim3, ch)
26
    # unet begins
27
28
    nfeatures = [2**feat*32 for feat in np.arange(6)]
29
    depth = len(nfeatures)    
30
31
    conv_ptr = []
32
33
    # input layer
34
    inputs = Input(input_size)
35
36
    # step down convolutional layers 
37
    pool = inputs
38
    for depth_cnt in xrange(depth):
39
40
        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
41
                      padding='same', 
42
                      activation='relu',
43
                      kernel_initializer='he_normal')(pool)
44
        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
45
                      padding='same', 
46
                      activation='relu',
47
                      kernel_initializer='he_normal')(conv)
48
49
        conv = BN(axis=-1, momentum=0.95, epsilon=0.001)(conv)
50
        conv = Dropout(rate=0.0)(conv)
51
52
        conv_ptr.append(conv)
53
54
        # Only maxpool till penultimate depth
55
        if depth_cnt < depth-1:
56
57
            # If size of input is odd, only do a 3x3 max pool
58
            xres = conv.shape.as_list()[1]
59
            if (xres % 2 == 0):
60
                pooling_size = (2,2)
61
            elif (xres % 2 == 1):
62
                pooling_size = (3,3)
63
64
            pool = MaxPooling2D(pool_size=pooling_size)(conv)
65
66
67
    # step up convolutional layers
68
    for depth_cnt in xrange(depth-2,-1,-1):
69
70
        deconv_shape = conv_ptr[depth_cnt].shape.as_list()
71
        deconv_shape[0] = None
72
73
        # If size of input is odd, then do a 3x3 deconv  
74
        if (deconv_shape[1] % 2 == 0):
75
            unpooling_size = (2,2)
76
        elif (deconv_shape[1] % 2 == 1):
77
            unpooling_size = (3,3)
78
79
        up = concatenate([Conv2DTranspose(nfeatures[depth_cnt],(3,3),
80
                          padding='same',
81
                          strides=unpooling_size,
82
                          output_shape=deconv_shape)(conv),
83
                          conv_ptr[depth_cnt]], 
84
                          axis=3)
85
86
        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
87
                      padding='same', 
88
                      activation='relu',
89
                      kernel_initializer='he_normal')(up)
90
        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
91
                      padding='same', 
92
                      activation='relu',
93
                      kernel_initializer='he_normal')(conv)
94
95
        conv = BN(axis=-1, momentum=0.95, epsilon=0.001)(conv)
96
        conv = Dropout(rate=0.00)(conv)
97
98
    # combine features
99
    recon = Conv2D(1, (1,1), padding='same', activation='sigmoid')(conv)
100
101
    model = Model(inputs=[inputs], outputs=[recon])
102
    plot_model(model, to_file='unet2d.png',show_shapes=True)
103
    
104
    return model
105
106
107
if __name__ == '__main__':
108
  
109
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
110
    img_size = (288,288,1)
111
    model = unet_2d_model(img_size)
112
    print(model.summary())
113