Diff of /drunet/model/dr_unet.py [000000] .. [2824d6]

Switch to side-by-side view

--- a
+++ b/drunet/model/dr_unet.py
@@ -0,0 +1,109 @@
+from tensorflow import keras
+import tensorflow as tf
+
+
+def conv_layer(inputs, filters, kernel_size=3, strides=1, need_activate=True):
+    out = keras.layers.Conv2D(filters, kernel_size, strides, padding='same')(inputs)
+    out = keras.layers.BatchNormalization()(out)
+    if need_activate:
+        out = keras.layers.ELU()(out)
+    return out
+
+
+def block_1(inputs, filters):
+    out = conv_layer(inputs, filters // 4, 1)
+    out = conv_layer(out, filters // 4, 3)
+    out = conv_layer(out, filters // 4, 3, need_activate=False)
+    res = conv_layer(inputs, filters // 4, 1)
+    out = keras.layers.ELU()(out + res)
+    return out
+
+
+def block_2(inputs, filters):
+    out = conv_layer(inputs, filters // 4, 1)
+    out = conv_layer(out, filters // 4, 3)
+    out = conv_layer(out, filters // 4, 3, need_activate=False)
+    res = conv_layer(inputs, filters // 4, 1)
+    out = keras.layers.ELU()(out + res)
+    return out
+
+
+def block_3(inputs, filters):
+    out = conv_layer(inputs, filters // 4, 1)
+    out = conv_layer(out, filters // 4, 3)
+    out = conv_layer(out, filters, 3, need_activate=False)
+    res = conv_layer(inputs, filters, 1)
+    out = keras.layers.ELU()(out + res)
+    return out
+
+
+def dr_unet(input_shape=(256, 256, 1), dims=32):
+    inputs = keras.Input(input_shape)
+    out = conv_layer(inputs, 16, 1)
+
+    out = block_1(out, dims)
+    out_256 = block_3(out, dims)
+    out = keras.layers.MaxPool2D(2, 2, padding='same')(out_256)
+
+    out = block_1(out, dims * 2)
+    out_128 = block_3(out, dims * 2)
+    out = keras.layers.MaxPool2D(2, 2, padding='same')(out_128)
+
+    out = block_1(out, dims * 4)
+    out_64 = block_3(out, dims * 4)
+    out = keras.layers.MaxPool2D(2, 2, padding='same')(out_64)
+
+    out = block_1(out, dims * 8)
+    out_32 = block_3(out, dims * 8)
+    out = keras.layers.MaxPool2D(2, 2, padding='same')(out_32)
+
+    out = block_1(out, dims * 16)
+    out_16 = block_3(out, dims * 16)
+    out = keras.layers.MaxPool2D(2, 2, padding='same')(out_16)
+
+    out = block_1(out, dims * 32)
+    out = block_3(out, dims * 32)
+
+    up_16 = keras.layers.Conv2DTranspose(filters=dims * 16, kernel_size=2, strides=2, padding='same')(out)
+    up = keras.layers.Concatenate()([up_16, out_16])
+    up = keras.layers.BatchNormalization()(up)
+    up = keras.layers.ELU()(up)
+
+    up = block_2(up, dims * 16)
+    up = block_3(up, dims * 16)
+    up_32 = keras.layers.Conv2DTranspose(filters=dims * 8, kernel_size=2, strides=2, padding='same')(up)
+    up = keras.layers.Concatenate()([up_32, out_32])
+    up = keras.layers.BatchNormalization()(up)
+    up = keras.layers.ELU()(up)
+
+    up = block_2(up, dims * 8)
+    up = block_3(up, dims * 8)
+    up_64 = keras.layers.Conv2DTranspose(filters=dims * 4, kernel_size=2, strides=2, padding='same')(up)
+    up = keras.layers.Concatenate()([up_64, out_64])
+    up = keras.layers.BatchNormalization()(up)
+    up = keras.layers.ELU()(up)
+
+    up = block_2(up, dims * 4)
+    up = block_3(up, dims * 4)
+    up_128 = keras.layers.Conv2DTranspose(filters=dims * 2, kernel_size=2, strides=2, padding='same')(up)
+    up = keras.layers.Concatenate()([up_128, out_128])
+    up = keras.layers.BatchNormalization()(up)
+    up = keras.layers.ELU()(up)
+
+    up = block_2(up, dims * 2)
+    up = block_3(up, dims * 2)
+    up_256 = keras.layers.Conv2DTranspose(filters=dims * 1, kernel_size=2, strides=2, padding='same')(up)
+    up = keras.layers.Concatenate()([up_256, out_256])
+    up = keras.layers.BatchNormalization()(up)
+    up = keras.layers.ELU()(up)
+
+    up = block_2(up, dims)
+    up = block_3(up, dims)
+    up = keras.layers.Conv2D(filters=1, kernel_size=1, strides=(1, 1), padding='same')(up)
+    up = keras.activations.sigmoid(up)
+    return keras.Model(inputs, up)
+
+
+if __name__ == '__main__':
+    model = dr_unet()
+    model.summary()