[606337]: / drunet / model / dr_unet.py

Download this file

110 lines (85 with data), 3.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
from tensorflow import keras
import tensorflow as tf
def conv_layer(inputs, filters, kernel_size=3, strides=1, need_activate=True):
    """Run ``inputs`` through Conv2D -> BatchNormalization (-> ELU).

    Args:
        inputs: Tensor handed to the ``Conv2D`` layer.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size; defaults to 3.
        strides: Convolution stride; defaults to 1.
        need_activate: When true (the default) an ``ELU`` layer is applied
            after batch normalisation; when false the normalised tensor is
            returned as-is.

    Returns:
        The output tensor of the last layer applied.
    """
    # 'same' padding is always used for the convolution.
    features = keras.layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
    )(inputs)
    features = keras.layers.BatchNormalization()(features)

    if not need_activate:
        return features
    return keras.layers.ELU()(features)
def block_1(inputs, filters):
    """Residual block whose output has ``filters // 4`` channels.

    Main path: 1x1 ``conv_layer`` -> 3x3 ``conv_layer`` -> 3x3 ``conv_layer``
    without activation, all with ``filters // 4`` filters.
    Shortcut: a single 1x1 ``conv_layer`` (with its default activation) on
    ``inputs``, also with ``filters // 4`` filters.
    The two branches are summed and passed through ``ELU``.

    Args:
        inputs: Input tensor of the block.
        filters: Nominal block width; every convolution here uses
            ``filters // 4`` filters.

    Returns:
        The activated sum of the main path and the shortcut.
    """
    width = filters // 4

    out = conv_layer(inputs, width, 1)
    out = conv_layer(out, width, 3)
    out = conv_layer(out, width, 3, need_activate=False)

    # The shortcut is created after the main path so the order in which
    # layers are instantiated stays the same as before.
    res = conv_layer(inputs, width, 1)

    # Merge with an explicit Add layer instead of the ``+`` operator so the
    # residual sum is an ordinary Keras layer in the model graph.
    merged = keras.layers.Add()([out, res])
    return keras.layers.ELU()(merged)
def block_2(inputs, filters):
    """Residual block used on the decoder path; outputs ``filters // 4`` channels.

    It has the same convolution layout as ``block_1``: a main path of
    1x1 -> 3x3 -> 3x3 (last one without activation) ``conv_layer`` calls and a
    1x1 ``conv_layer`` shortcut, every convolution using ``filters // 4``
    filters. The branches are summed and passed through ``ELU``.

    Args:
        inputs: Input tensor of the block.
        filters: Nominal block width; every convolution here uses
            ``filters // 4`` filters.

    Returns:
        The activated sum of the main path and the shortcut.
    """
    width = filters // 4

    out = conv_layer(inputs, width, 1)
    out = conv_layer(out, width, 3)
    out = conv_layer(out, width, 3, need_activate=False)

    # Shortcut branch; kept after the main path to preserve the order in
    # which layers are instantiated.
    res = conv_layer(inputs, width, 1)

    # Merge with an explicit Add layer instead of the ``+`` operator so the
    # residual sum is an ordinary Keras layer in the model graph.
    merged = keras.layers.Add()([out, res])
    return keras.layers.ELU()(merged)
def block_3(inputs, filters):
    """Residual block that expands to ``filters`` output channels.

    Main path: 1x1 and 3x3 ``conv_layer`` calls with ``filters // 4`` filters,
    then a 3x3 ``conv_layer`` with ``filters`` filters and no activation.
    Shortcut: a 1x1 ``conv_layer`` (with its default activation) on ``inputs``
    with ``filters`` filters.
    The two branches are summed and passed through ``ELU``.

    Args:
        inputs: Input tensor of the block.
        filters: Number of output channels of the block.

    Returns:
        The activated sum of the main path and the shortcut.
    """
    narrow = filters // 4

    out = conv_layer(inputs, narrow, 1)
    out = conv_layer(out, narrow, 3)
    out = conv_layer(out, filters, 3, need_activate=False)

    # Shortcut branch; kept after the main path to preserve the order in
    # which layers are instantiated.
    res = conv_layer(inputs, filters, 1)

    # Merge with an explicit Add layer instead of the ``+`` operator so the
    # residual sum is an ordinary Keras layer in the model graph.
    merged = keras.layers.Add()([out, res])
    return keras.layers.ELU()(merged)
def _encoder_stage(x, filters):
    """Apply ``block_1`` then ``block_3`` with the same ``filters`` value."""
    x = block_1(x, filters)
    return block_3(x, filters)


def _decoder_stage(x, skip, filters):
    """Upsample ``x`` by 2, fuse it with ``skip`` and refine the result."""
    upsampled = keras.layers.Conv2DTranspose(
        filters=filters, kernel_size=2, strides=2, padding='same')(x)
    x = keras.layers.Concatenate()([upsampled, skip])
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ELU()(x)
    x = block_2(x, filters)
    return block_3(x, filters)


def dr_unet(input_shape=(256, 256, 1), dims=32):
    """Build the DR-UNet model as a ``keras.Model``.

    Layout:
        * stem: 1x1 ``conv_layer`` with 16 filters;
        * encoder: five stages (``block_1`` + ``block_3``) at widths
          ``dims * 1, 2, 4, 8, 16``, each followed by 2x2 max pooling with
          stride 2; the stage output before pooling is kept as a skip tensor;
        * bottleneck: one more ``block_1`` + ``block_3`` at ``dims * 32``;
        * decoder: five stages at widths ``dims * 16, 8, 4, 2, 1``; each one
          upsamples with a stride-2 transposed convolution, concatenates the
          matching skip tensor, applies BatchNormalization + ELU, then
          ``block_2`` + ``block_3``;
        * head: 1x1 convolution to one channel followed by a sigmoid.

    Args:
        input_shape: Shape passed to ``keras.Input`` (batch axis excluded).
        dims: Base width; each stage uses a multiple of this value.

    Returns:
        A ``keras.Model`` mapping the input tensor to the sigmoid output.

    NOTE(review): the network pools five times and upsamples five times, so
    fixed spatial sizes presumably need to be multiples of 32 for the skip
    concatenations to line up -- confirm against the intended inputs.
    """
    inputs = keras.Input(input_shape)
    x = conv_layer(inputs, 16, 1)

    # Encoder: remember each stage output so the decoder can concatenate it.
    skips = []
    for multiplier in (1, 2, 4, 8, 16):
        x = _encoder_stage(x, dims * multiplier)
        skips.append(x)
        x = keras.layers.MaxPool2D(2, 2, padding='same')(x)

    # Bottleneck at the coarsest resolution.
    x = _encoder_stage(x, dims * 32)

    # Decoder: consume the skip tensors from deepest to shallowest.
    for multiplier in (16, 8, 4, 2, 1):
        x = _decoder_stage(x, skips.pop(), dims * multiplier)

    x = keras.layers.Conv2D(
        filters=1, kernel_size=1, strides=(1, 1), padding='same')(x)
    # Use an Activation layer rather than calling the activation function on
    # a symbolic tensor, so the sigmoid is an ordinary Keras layer in the
    # model graph.
    outputs = keras.layers.Activation('sigmoid')(x)
    return keras.Model(inputs, outputs)
if __name__ == '__main__':
    # Build the network with its default arguments and print its summary.
    dr_unet().summary()