from keras import backend as K
from keras.models import Model
from keras.layers import Input, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, concatenate, Concatenate, UpSampling2D, Activation
from keras.losses import categorical_crossentropy
from keras.applications.inception_resnet_v2 import InceptionResNetV2, inception_resnet_block, conv2d_bn
from keras.applications.densenet import DenseNet121, dense_block, transition_block
# Axis used for BatchNormalization and channel concatenation; 3 is the last
# axis of a 4-D (batch, rows, cols, channels) tensor, i.e. this assumes the
# channels_last data format.
bn_axis = channel_axis = 3
def schedule_steps(epoch, steps):
    """Pick the learning rate for ``epoch`` from a piecewise-constant schedule.

    ``steps`` is an ordered sequence of entries whose item 0 is a learning rate
    and item 1 is an epoch boundary. The first entry whose boundary is strictly
    greater than ``epoch`` is used. If ``epoch`` is past every boundary, the last
    entry's rate is reused. The chosen rate is printed before it is returned.

    Raises:
        IndexError: if ``steps`` is empty.
    """
    for entry in steps:
        if entry[1] > epoch:
            chosen = entry[0]
            break
    else:
        # Every boundary has been passed (or there are no entries): keep the
        # final rate.
        chosen = steps[-1][0]
    print("Setting learning rate to {}".format(chosen))
    return chosen
def dice_coef(y_true, y_pred, smooth=1.):
    """Smoothed Dice coefficient between two tensors, pooled over all elements.

    Both tensors are flattened before summing, so the score is computed over
    the whole batch at once rather than averaged per sample:

        dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)

    Args:
        y_true: ground-truth tensor (presumably values in [0, 1] — confirm with
            the data pipeline).
        y_pred: predicted tensor with the same number of elements as ``y_true``.
        smooth: additive smoothing term. It keeps the ratio defined when both
            tensors are all zero (the score is then 1). Defaults to 1, the value
            that was previously hard-coded.

    Returns:
        A scalar tensor.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus ``dice_coef``, so a perfect overlap gives 0."""
    score = dice_coef(y_true, y_pred)
    return 1 - score
def softmax_dice_loss(y_true, y_pred):
    """Weighted mix of categorical cross-entropy and per-channel Dice losses.

    The result is ``0.5 * CE + 0.3 * dice_loss(ch0) + 0.2 * dice_loss(ch1)``.
    Only channels 0 and 1 get a Dice term; any further channels (the U-Net
    builders in this file emit 3) contribute through the cross-entropy only.
    """
    cross_entropy = categorical_crossentropy(y_true, y_pred)
    dice_loss_ch0 = dice_coef_loss(y_true[..., 0], y_pred[..., 0])
    dice_loss_ch1 = dice_coef_loss(y_true[..., 1], y_pred[..., 1])
    return cross_entropy * 0.5 + dice_loss_ch0 * 0.3 + dice_loss_ch1 * 0.2
def dice_coef_rounded_ch0(y_true, y_pred):
    """Dice coefficient on channel 0 after rounding both tensors.

    ``K.round`` rounds to the nearest integer. For values in [0, 1] this
    binarizes the masks, so this is a hard (thresholded-at-0.5) Dice score
    suitable as a metric.

    Delegates to ``dice_coef`` so the smoothed formula lives in one place.
    Rounding before flattening gives the same elements as flattening before
    rounding, so the result is unchanged.
    """
    return dice_coef(K.round(y_true[..., 0]), K.round(y_pred[..., 0]))
def dice_coef_rounded_ch1(y_true, y_pred):
    """Dice coefficient on channel 1 after rounding both tensors.

    ``K.round`` rounds to the nearest integer. For values in [0, 1] this
    binarizes the masks, so this is a hard (thresholded-at-0.5) Dice score
    suitable as a metric.

    Delegates to ``dice_coef`` so the smoothed formula lives in one place.
    Rounding before flattening gives the same elements as flattening before
    rounding, so the result is unchanged.
    """
    return dice_coef(K.round(y_true[..., 1]), K.round(y_pred[..., 1]))
def conv_block(prev, num_filters, kernel=(3, 3), strides=(1, 1), act='relu', prefix=None):
    """Apply Conv2D -> BatchNormalization -> Activation to ``prev``.

    The convolution uses 'same' padding and he_normal initialisation.
    BatchNormalization normalises over the module-level ``bn_axis``.

    Args:
        prev: input tensor.
        num_filters: number of convolution filters.
        kernel: convolution kernel size.
        strides: convolution strides.
        act: activation passed to ``Activation``.
        prefix: optional name prefix. When given, the three layers are named
            ``<prefix>_conv``, ``<prefix>_norm`` and ``<prefix>_act``. Otherwise
            ``name=None`` is passed to each layer.

    Returns:
        The activated output tensor.
    """
    def layer_name(suffix):
        # No explicit layer name unless a prefix was supplied.
        return None if prefix is None else prefix + suffix

    out = Conv2D(num_filters, kernel, padding='same', kernel_initializer='he_normal',
                 strides=strides, name=layer_name('_conv'))(prev)
    out = BatchNormalization(name=layer_name('_norm'), axis=bn_axis)(out)
    return Activation(act, name=layer_name('_act'))(out)
def get_densenet121_unet_softmax(input_shape, weights='imagenet'):
    """Build a U-Net with a DenseNet121 encoder and a 3-class per-pixel softmax.

    Encoder:
        Follows the DenseNet121 layer sequence (stem, dense blocks conv2-conv5
        with transition blocks pool2-pool4, final 'bn') for a 4-channel input.
        Skip tensors are taken after the stem activation and after dense blocks
        conv2, conv3 and conv4.

    Decoder:
        Upsamples five times, concatenating the matching skip tensor at the first
        four steps, and ends in a 1x1 softmax convolution with 3 outputs.

    Args:
        input_shape: spatial shape tuple ``(rows, cols)``; a 4-channel axis is
            appended. Presumably rows/cols must be divisible by 32 so the
            upsampled tensors match the skip tensors — TODO confirm.
        weights: ``'imagenet'`` initialises the encoder from keras'
            pretrained DenseNet121. The 4-channel stem conv is adapted from the
            3-channel kernel and stays trainable. Every later pretrained layer
            is copied by index and frozen. Any other value leaves the model
            randomly initialised.

    Returns:
        An uncompiled ``Model`` from the 4-channel input to the softmax map.
    """
    block_counts = [6, 12, 24, 16]
    inputs = Input(input_shape + (4,))

    # Stem.
    net = ZeroPadding2D(padding=((3, 3), (3, 3)))(inputs)
    net = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(net)
    net = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(net)
    net = Activation('relu', name='conv1/relu')(net)
    skips = [net]
    net = ZeroPadding2D(padding=((1, 1), (1, 1)))(net)
    net = MaxPooling2D(3, strides=2, name='pool1')(net)

    # Dense blocks 1-3, each followed by a downsampling transition block; the
    # dense-block output (before the transition) is kept as a skip tensor.
    stage_names = zip(block_counts[:3], ('conv2', 'conv3', 'conv4'), ('pool2', 'pool3', 'pool4'))
    for count, block_name, pool_name in stage_names:
        net = dense_block(net, count, name=block_name)
        skips.append(net)
        net = transition_block(net, 0.5, name=pool_name)

    # Last dense block and final normalisation form the bottleneck.
    net = dense_block(net, block_counts[3], name='conv5')
    net = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(net)

    # Decoder: upsample, fuse with the matching encoder skip, refine.
    for filters, skip in zip((320, 256, 128, 96), reversed(skips)):
        net = conv_block(UpSampling2D()(net), filters)
        net = concatenate([net, skip], axis=-1)
        net = conv_block(net, filters)
    net = conv_block(UpSampling2D()(net), 64)
    net = conv_block(net, 64)
    outputs = Conv2D(3, (1, 1), activation='softmax')(net)
    model = Model(inputs, outputs)

    if weights == 'imagenet':
        reference = DenseNet121(input_shape=input_shape + (3,), weights=weights, include_top=False)
        # Layer 2 is 'conv1/conv'. Its pretrained kernel has 3 input channels,
        # so scale those by 0.9 and seed the 4th channel with 0.1x the kernel
        # of pretrained input channel 1.
        stem_conv = model.layers[2]
        pretrained_kernel = reference.layers[2].get_weights()[0]
        stem_weights = stem_conv.get_weights()
        stem_weights[0][:, :, :3, :] = 0.9 * pretrained_kernel[:, :, :3, :]
        stem_weights[0][:, :, 3, :] = 0.1 * pretrained_kernel[:, :, 1, :]
        stem_conv.set_weights(stem_weights)
        # Remaining pretrained layers are assumed to line up index-for-index
        # with this model's encoder; copy and freeze them.
        for source, target in zip(reference.layers[3:], model.layers[3:]):
            target.set_weights(source.get_weights())
            target.trainable = False
    return model
def get_inception_resnet_v2_unet_softmax(input_shape, weights='imagenet'):
    """Build a U-Net with an Inception-ResNet-v2 encoder and a 3-class softmax.

    Encoder:
        Follows the InceptionResNetV2 layer sequence (stem, mixed_5b,
        10x block35, mixed_6a, 20x block17, mixed_7a, 10x block8, conv_7b) for
        a 4-channel input, using 'same' padding in the strided layers. The
        five stages therefore sit at strides 2, 4, 8, 16 and 32.

    Decoder:
        Upsamples five times, concatenating the stride-16/8/4/2 encoder tensors
        on the way, and ends in a 1x1 softmax convolution with 3 outputs.

    Args:
        input_shape: spatial shape ``(rows, cols)``; any sequence is accepted
            and a 4-channel axis is appended. Presumably rows/cols must be
            divisible by 32 so the upsampled tensors match the skip tensors —
            TODO confirm.
        weights: ``'imagenet'`` initialises the encoder from keras' pretrained
            InceptionResNetV2:
            - The 4-channel first conv is adapted from the 3-channel kernel
              and stays trainable.
            - The later pretrained layers (except the pretrained model's last
              one) are copied by index and frozen.
            Any other value leaves the model randomly initialised.

    Returns:
        An uncompiled ``Model`` from the 4-channel input to the softmax map.
    """
    input_shape = tuple(input_shape)
    inp = Input(input_shape + (4,))
    concat_axis = 1 if K.image_data_format() == 'channels_first' else 3

    # Stem: conv1 (64 ch) at stride 2, conv2 (192 ch) at stride 4.
    x = conv2d_bn(inp, 32, 3, strides=2, padding='same')
    x = conv2d_bn(x, 32, 3, padding='same')
    x = conv2d_bn(x, 64, 3)
    conv1 = x
    x = MaxPooling2D(3, strides=2, padding='same')(x)
    x = conv2d_bn(x, 80, 1, padding='same')
    x = conv2d_bn(x, 192, 3, padding='same')
    conv2 = x
    x = MaxPooling2D(3, strides=2, padding='same')(x)

    # Mixed 5b (Inception-A block): 96 + 64 + 96 + 64 = 320 channels, stride 8.
    branch_0 = conv2d_bn(x, 96, 1)
    branch_1 = conv2d_bn(x, 48, 1)
    branch_1 = conv2d_bn(branch_1, 64, 5)
    branch_2 = conv2d_bn(x, 64, 1)
    branch_2 = conv2d_bn(branch_2, 96, 3)
    branch_2 = conv2d_bn(branch_2, 96, 3)
    branch_pool = AveragePooling2D(3, strides=1, padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 64, 1)
    x = Concatenate(axis=concat_axis, name='mixed_5b')(
        [branch_0, branch_1, branch_2, branch_pool])

    # 10x block35 (Inception-ResNet-A).
    for block_idx in range(1, 11):
        x = inception_resnet_block(x, scale=0.17, block_type='block35',
                                   block_idx=block_idx)
    conv3 = x

    # Mixed 6a (Reduction-A block): 1088 channels, stride 16.
    branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='same')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 256, 3)
    branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='same')
    branch_pool = MaxPooling2D(3, strides=2, padding='same')(x)
    x = Concatenate(axis=concat_axis, name='mixed_6a')(
        [branch_0, branch_1, branch_pool])

    # 20x block17 (Inception-ResNet-B).
    for block_idx in range(1, 21):
        x = inception_resnet_block(x, scale=0.1, block_type='block17',
                                   block_idx=block_idx)
    conv4 = x

    # Mixed 7a (Reduction-B block): 2080 channels, stride 32.
    branch_0 = conv2d_bn(x, 256, 1)
    branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='same')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='same')
    branch_2 = conv2d_bn(x, 256, 1)
    branch_2 = conv2d_bn(branch_2, 288, 3)
    branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='same')
    branch_pool = MaxPooling2D(3, strides=2, padding='same')(x)
    x = Concatenate(axis=concat_axis, name='mixed_7a')(
        [branch_0, branch_1, branch_2, branch_pool])

    # 10x block8 (Inception-ResNet-C). The 10th uses scale 1 and no activation.
    for block_idx in range(1, 10):
        x = inception_resnet_block(x, scale=0.2, block_type='block8',
                                   block_idx=block_idx)
    x = inception_resnet_block(x, scale=1., activation=None, block_type='block8',
                               block_idx=10)

    # Final 1x1 convolution: 1536 channels, stride 32.
    conv5 = conv2d_bn(x, 1536, 1, name='conv_7b')

    # Decoder: upsample, fuse with the matching encoder skip, refine.
    x = conv5
    for filters, skip in zip((320, 256, 128, 96), (conv4, conv3, conv2, conv1)):
        x = conv_block(UpSampling2D()(x), filters)
        x = concatenate([x, skip], axis=-1)
        x = conv_block(x, filters)
    x = conv_block(UpSampling2D()(x), 64)
    x = conv_block(x, 64)
    res = Conv2D(3, (1, 1), activation='softmax')(x)
    model = Model(inp, res)

    if weights == 'imagenet':
        pretrained = InceptionResNetV2(weights=weights, include_top=False,
                                       input_shape=input_shape + (3,))
        # Layer 1 is the first convolution. It sees 4 input channels, so the
        # pretrained 3-channel kernel cannot be copied directly. Previously it
        # was left random, while its (frozen) BatchNormalization at layer 2
        # received pretrained statistics.
        # Adapt the kernel the same way get_densenet121_unet_softmax does:
        # scale the 3 pretrained input channels by 0.9 and seed the 4th with
        # 0.1x pretrained input channel 1. The layer stays trainable.
        pretrained_kernel = pretrained.layers[1].get_weights()[0]
        first_conv_weights = model.layers[1].get_weights()
        first_conv_weights[0][:, :, :3, :] = 0.9 * pretrained_kernel[:, :, :3, :]
        first_conv_weights[0][:, :, 3, :] = 0.1 * pretrained_kernel[:, :, 1, :]
        model.layers[1].set_weights(first_conv_weights)
        # Remaining encoder layers are assumed to line up index-for-index with
        # the pretrained model; copy and freeze them. The pretrained model's
        # last layer is skipped, as in the original code.
        for i in range(2, len(pretrained.layers) - 1):
            model.layers[i].set_weights(pretrained.layers[i].get_weights())
            model.layers[i].trainable = False
    return model