Diff of /utils/models.py [000000] .. [06a92b]

Switch to side-by-side view

--- a
+++ b/utils/models.py
@@ -0,0 +1,113 @@
+# Authors:
+# Akshay Chaudhari and Zhongnan Fang
+# May 2018
+# akshaysc@stanford.edu
+
+from __future__ import print_function, division
+
+import numpy as np
+import pickle
+import json
+import math
+import os
+
+from keras.models import Model, Sequential
+from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, add, Lambda, Dropout, AlphaDropout
+from keras.layers import BatchNormalization as BN
+from keras.utils import plot_model
+from keras import backend as K
+import tensorflow as tf
+
+def unet_2d_model(input_size):
+
+    # input size is a tuple of the size of the image
+    # assuming channel last
+    # input_size = (dim1, dim2, dim3, ch)
+    # unet begins
+
+    nfeatures = [2**feat*32 for feat in np.arange(6)]
+    depth = len(nfeatures)    
+
+    conv_ptr = []
+
+    # input layer
+    inputs = Input(input_size)
+
+    # step down convolutional layers 
+    pool = inputs
+    for depth_cnt in xrange(depth):
+
+        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
+                      padding='same', 
+                      activation='relu',
+                      kernel_initializer='he_normal')(pool)
+        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
+                      padding='same', 
+                      activation='relu',
+                      kernel_initializer='he_normal')(conv)
+
+        conv = BN(axis=-1, momentum=0.95, epsilon=0.001)(conv)
+        conv = Dropout(rate=0.0)(conv)
+
+        conv_ptr.append(conv)
+
+        # Only maxpool till penultimate depth
+        if depth_cnt < depth-1:
+
+            # If size of input is odd, only do a 3x3 max pool
+            xres = conv.shape.as_list()[1]
+            if (xres % 2 == 0):
+                pooling_size = (2,2)
+            elif (xres % 2 == 1):
+                pooling_size = (3,3)
+
+            pool = MaxPooling2D(pool_size=pooling_size)(conv)
+
+
+    # step up convolutional layers
+    for depth_cnt in xrange(depth-2,-1,-1):
+
+        deconv_shape = conv_ptr[depth_cnt].shape.as_list()
+        deconv_shape[0] = None
+
+        # If size of input is odd, then do a 3x3 deconv  
+        if (deconv_shape[1] % 2 == 0):
+            unpooling_size = (2,2)
+        elif (deconv_shape[1] % 2 == 1):
+            unpooling_size = (3,3)
+
+        up = concatenate([Conv2DTranspose(nfeatures[depth_cnt],(3,3),
+                          padding='same',
+                          strides=unpooling_size,
+                          output_shape=deconv_shape)(conv),
+                          conv_ptr[depth_cnt]], 
+                          axis=3)
+
+        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
+                      padding='same', 
+                      activation='relu',
+                      kernel_initializer='he_normal')(up)
+        conv = Conv2D(nfeatures[depth_cnt], (3,3), 
+                      padding='same', 
+                      activation='relu',
+                      kernel_initializer='he_normal')(conv)
+
+        conv = BN(axis=-1, momentum=0.95, epsilon=0.001)(conv)
+        conv = Dropout(rate=0.00)(conv)
+
+    # combine features
+    recon = Conv2D(1, (1,1), padding='same', activation='sigmoid')(conv)
+
+    model = Model(inputs=[inputs], outputs=[recon])
+    plot_model(model, to_file='unet2d.png',show_shapes=True)
+    
+    return model
+
+
+if __name__ == '__main__':
+  
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+    img_size = (288,288,1)
+    model = unet_2d_model(img_size)
+    print(model.summary())
+