--- a +++ b/drunet/model/dr_unet.py @@ -0,0 +1,109 @@ +from tensorflow import keras +import tensorflow as tf + + +def conv_layer(inputs, filters, kernel_size=3, strides=1, need_activate=True): + out = keras.layers.Conv2D(filters, kernel_size, strides, padding='same')(inputs) + out = keras.layers.BatchNormalization()(out) + if need_activate: + out = keras.layers.ELU()(out) + return out + + +def block_1(inputs, filters): + out = conv_layer(inputs, filters // 4, 1) + out = conv_layer(out, filters // 4, 3) + out = conv_layer(out, filters // 4, 3, need_activate=False) + res = conv_layer(inputs, filters // 4, 1) + out = keras.layers.ELU()(out + res) + return out + + +def block_2(inputs, filters): + out = conv_layer(inputs, filters // 4, 1) + out = conv_layer(out, filters // 4, 3) + out = conv_layer(out, filters // 4, 3, need_activate=False) + res = conv_layer(inputs, filters // 4, 1) + out = keras.layers.ELU()(out + res) + return out + + +def block_3(inputs, filters): + out = conv_layer(inputs, filters // 4, 1) + out = conv_layer(out, filters // 4, 3) + out = conv_layer(out, filters, 3, need_activate=False) + res = conv_layer(inputs, filters, 1) + out = keras.layers.ELU()(out + res) + return out + + +def dr_unet(input_shape=(256, 256, 1), dims=32): + inputs = keras.Input(input_shape) + out = conv_layer(inputs, 16, 1) + + out = block_1(out, dims) + out_256 = block_3(out, dims) + out = keras.layers.MaxPool2D(2, 2, padding='same')(out_256) + + out = block_1(out, dims * 2) + out_128 = block_3(out, dims * 2) + out = keras.layers.MaxPool2D(2, 2, padding='same')(out_128) + + out = block_1(out, dims * 4) + out_64 = block_3(out, dims * 4) + out = keras.layers.MaxPool2D(2, 2, padding='same')(out_64) + + out = block_1(out, dims * 8) + out_32 = block_3(out, dims * 8) + out = keras.layers.MaxPool2D(2, 2, padding='same')(out_32) + + out = block_1(out, dims * 16) + out_16 = block_3(out, dims * 16) + out = keras.layers.MaxPool2D(2, 2, padding='same')(out_16) + + out = block_1(out, dims * 32) + out = block_3(out, dims * 32) + + up_16 = keras.layers.Conv2DTranspose(filters=dims * 16, kernel_size=2, strides=2, padding='same')(out) + up = keras.layers.Concatenate()([up_16, out_16]) + up = keras.layers.BatchNormalization()(up) + up = keras.layers.ELU()(up) + + up = block_2(up, dims * 16) + up = block_3(up, dims * 16) + up_32 = keras.layers.Conv2DTranspose(filters=dims * 8, kernel_size=2, strides=2, padding='same')(up) + up = keras.layers.Concatenate()([up_32, out_32]) + up = keras.layers.BatchNormalization()(up) + up = keras.layers.ELU()(up) + + up = block_2(up, dims * 8) + up = block_3(up, dims * 8) + up_64 = keras.layers.Conv2DTranspose(filters=dims * 4, kernel_size=2, strides=2, padding='same')(up) + up = keras.layers.Concatenate()([up_64, out_64]) + up = keras.layers.BatchNormalization()(up) + up = keras.layers.ELU()(up) + + up = block_2(up, dims * 4) + up = block_3(up, dims * 4) + up_128 = keras.layers.Conv2DTranspose(filters=dims * 2, kernel_size=2, strides=2, padding='same')(up) + up = keras.layers.Concatenate()([up_128, out_128]) + up = keras.layers.BatchNormalization()(up) + up = keras.layers.ELU()(up) + + up = block_2(up, dims * 2) + up = block_3(up, dims * 2) + up_256 = keras.layers.Conv2DTranspose(filters=dims * 1, kernel_size=2, strides=2, padding='same')(up) + up = keras.layers.Concatenate()([up_256, out_256]) + up = keras.layers.BatchNormalization()(up) + up = keras.layers.ELU()(up) + + up = block_2(up, dims) + up = block_3(up, dims) + up = keras.layers.Conv2D(filters=1, kernel_size=1, strides=(1, 1), padding='same')(up) + up = keras.activations.sigmoid(up) + return keras.Model(inputs, up) + + +if __name__ == '__main__': + model = dr_unet() + model.summary()