|
a |
|
b/utils/models.py |
|
|
1 |
# Authors: |
|
|
2 |
# Akshay Chaudhari and Zhongnan Fang |
|
|
3 |
# May 2018 |
|
|
4 |
# akshaysc@stanford.edu |
|
|
5 |
|
|
|
6 |
from __future__ import print_function, division |
|
|
7 |
|
|
|
8 |
import numpy as np |
|
|
9 |
import pickle |
|
|
10 |
import json |
|
|
11 |
import math |
|
|
12 |
import os |
|
|
13 |
|
|
|
14 |
from keras.models import Model, Sequential |
|
|
15 |
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, add, Lambda, Dropout, AlphaDropout |
|
|
16 |
from keras.layers import BatchNormalization as BN |
|
|
17 |
from keras.utils import plot_model |
|
|
18 |
from keras import backend as K |
|
|
19 |
import tensorflow as tf |
|
|
20 |
|
|
|
21 |
def unet_2d_model(input_size): |
|
|
22 |
|
|
|
23 |
# input size is a tuple of the size of the image |
|
|
24 |
# assuming channel last |
|
|
25 |
# input_size = (dim1, dim2, dim3, ch) |
|
|
26 |
# unet begins |
|
|
27 |
|
|
|
28 |
nfeatures = [2**feat*32 for feat in np.arange(6)] |
|
|
29 |
depth = len(nfeatures) |
|
|
30 |
|
|
|
31 |
conv_ptr = [] |
|
|
32 |
|
|
|
33 |
# input layer |
|
|
34 |
inputs = Input(input_size) |
|
|
35 |
|
|
|
36 |
# step down convolutional layers |
|
|
37 |
pool = inputs |
|
|
38 |
for depth_cnt in xrange(depth): |
|
|
39 |
|
|
|
40 |
conv = Conv2D(nfeatures[depth_cnt], (3,3), |
|
|
41 |
padding='same', |
|
|
42 |
activation='relu', |
|
|
43 |
kernel_initializer='he_normal')(pool) |
|
|
44 |
conv = Conv2D(nfeatures[depth_cnt], (3,3), |
|
|
45 |
padding='same', |
|
|
46 |
activation='relu', |
|
|
47 |
kernel_initializer='he_normal')(conv) |
|
|
48 |
|
|
|
49 |
conv = BN(axis=-1, momentum=0.95, epsilon=0.001)(conv) |
|
|
50 |
conv = Dropout(rate=0.0)(conv) |
|
|
51 |
|
|
|
52 |
conv_ptr.append(conv) |
|
|
53 |
|
|
|
54 |
# Only maxpool till penultimate depth |
|
|
55 |
if depth_cnt < depth-1: |
|
|
56 |
|
|
|
57 |
# If size of input is odd, only do a 3x3 max pool |
|
|
58 |
xres = conv.shape.as_list()[1] |
|
|
59 |
if (xres % 2 == 0): |
|
|
60 |
pooling_size = (2,2) |
|
|
61 |
elif (xres % 2 == 1): |
|
|
62 |
pooling_size = (3,3) |
|
|
63 |
|
|
|
64 |
pool = MaxPooling2D(pool_size=pooling_size)(conv) |
|
|
65 |
|
|
|
66 |
|
|
|
67 |
# step up convolutional layers |
|
|
68 |
for depth_cnt in xrange(depth-2,-1,-1): |
|
|
69 |
|
|
|
70 |
deconv_shape = conv_ptr[depth_cnt].shape.as_list() |
|
|
71 |
deconv_shape[0] = None |
|
|
72 |
|
|
|
73 |
# If size of input is odd, then do a 3x3 deconv |
|
|
74 |
if (deconv_shape[1] % 2 == 0): |
|
|
75 |
unpooling_size = (2,2) |
|
|
76 |
elif (deconv_shape[1] % 2 == 1): |
|
|
77 |
unpooling_size = (3,3) |
|
|
78 |
|
|
|
79 |
up = concatenate([Conv2DTranspose(nfeatures[depth_cnt],(3,3), |
|
|
80 |
padding='same', |
|
|
81 |
strides=unpooling_size, |
|
|
82 |
output_shape=deconv_shape)(conv), |
|
|
83 |
conv_ptr[depth_cnt]], |
|
|
84 |
axis=3) |
|
|
85 |
|
|
|
86 |
conv = Conv2D(nfeatures[depth_cnt], (3,3), |
|
|
87 |
padding='same', |
|
|
88 |
activation='relu', |
|
|
89 |
kernel_initializer='he_normal')(up) |
|
|
90 |
conv = Conv2D(nfeatures[depth_cnt], (3,3), |
|
|
91 |
padding='same', |
|
|
92 |
activation='relu', |
|
|
93 |
kernel_initializer='he_normal')(conv) |
|
|
94 |
|
|
|
95 |
conv = BN(axis=-1, momentum=0.95, epsilon=0.001)(conv) |
|
|
96 |
conv = Dropout(rate=0.00)(conv) |
|
|
97 |
|
|
|
98 |
# combine features |
|
|
99 |
recon = Conv2D(1, (1,1), padding='same', activation='sigmoid')(conv) |
|
|
100 |
|
|
|
101 |
model = Model(inputs=[inputs], outputs=[recon]) |
|
|
102 |
plot_model(model, to_file='unet2d.png',show_shapes=True) |
|
|
103 |
|
|
|
104 |
return model |
|
|
105 |
|
|
|
106 |
|
|
|
107 |
if __name__ == '__main__': |
|
|
108 |
|
|
|
109 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
110 |
img_size = (288,288,1) |
|
|
111 |
model = unet_2d_model(img_size) |
|
|
112 |
print(model.summary()) |
|
|
113 |
|