[beb348]: / model.py

Download this file

123 lines (90 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
from keras.models import Model,load_model
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Dropout,GaussianNoise, Input,Activation
from keras.layers.normalization import BatchNormalization
from keras.layers import Conv2DTranspose,UpSampling2D,concatenate,add
from keras.optimizers import SGD
import keras.backend as K
from losses import *
# Make channels_last, i.e. (batch, height, width, channels), the global
# default image layout for the Keras backend; the model below is built
# assuming this layout.
K.set_image_data_format("channels_last")
# U-Net model
class Unet_model(object):
    """Residual U-Net segmentation model built and compiled with Keras.

    The network is assembled in ``compile_unet``: Gaussian input noise, a
    2x2 stem convolution to 64 channels, a recursive encoder/decoder made of
    residual blocks (``level_block``), and a 1x1 convolution followed by a
    softmax (``unet``).

    Args:
        img_shape: shape passed to ``keras.layers.Input`` (no batch axis);
            presumably ``(height, width, channels)`` given the channels_last
            layout used throughout -- TODO confirm against callers.
        load_model_weights: optional path to a weights file that is loaded
            with ``Model.load_weights`` after compilation (e.g. for
            prediction). ``None`` skips loading.

    Attributes:
        model: the compiled ``keras.models.Model`` built by ``compile_unet``.
    """

    def __init__(self, img_shape, load_model_weights=None):
        self.img_shape = img_shape
        self.load_model_weights = load_model_weights
        self.model = self.compile_unet()

    def compile_unet(self):
        """Build and compile the U-Net, then optionally load weights.

        Uses SGD (lr=0.08, momentum=0.9, decay=5e-6) with the generalized
        dice loss and three dice metrics.

        Returns:
            The compiled ``keras.models.Model``.
        """
        i = Input(shape=self.img_shape)
        # Add Gaussian noise to the input to combat overfitting.
        i_ = GaussianNoise(0.01)(i)
        # 64 channels here must equal unet()'s start_ch (default 64): the
        # first residual block adds this tensor to a start_ch-channel branch.
        i_ = Conv2D(64, 2, padding='same', data_format='channels_last')(i_)
        out = self.unet(inputs=i_)
        # Keras 2 keyword names; `input=`/`output=` are the deprecated
        # Keras 1 spellings.
        model = Model(inputs=i, outputs=out)
        sgd = SGD(lr=0.08, momentum=0.9, decay=5e-6, nesterov=False)
        # gen_dice_loss and the dice_*_metric functions are not defined in
        # this file; presumably they come from `from losses import *`.
        model.compile(loss=gen_dice_loss, optimizer=sgd,
                      metrics=[dice_whole_metric, dice_core_metric, dice_en_metric])
        # Load weights if a path was given (e.g. for prediction).
        if self.load_model_weights is not None:
            model.load_weights(self.load_model_weights)
        return model

    def unet(self, inputs, nb_classes=4, start_ch=64, depth=3, inc_rate=2.,
             activation='relu', dropout=0.0, batchnorm=True, upconv=True,
             format_='channels_last'):
        """Attach the U-Net body and classification head to ``inputs``.

        Args:
            inputs: Keras tensor; must have ``start_ch`` channels because the
                first residual block adds it to a ``start_ch``-channel branch.
            nb_classes: output channels of the final 1x1 convolution/softmax.
            start_ch: channel count at the top level; multiplied by
                ``inc_rate`` at each downsampling step.
            depth: number of stride-2 downsampling levels.
            inc_rate: channel growth factor per level.
            activation: forwarded to ``level_block`` but unused; the residual
                blocks apply PReLU instead.
            dropout: forwarded to ``level_block`` but never applied.
            batchnorm: apply BatchNormalization inside the residual blocks.
            upconv: decode with UpSampling2D + 2x2 Conv2D when True,
                otherwise with a stride-2 3x3 Conv2DTranspose.
            format_: ``data_format`` passed to the convolution/upsampling
                layers. NOTE(review): BatchNormalization, concatenate,
                PReLU(shared_axes=[1, 2]) and the softmax all keep their
                channels_last defaults, so only 'channels_last' is consistent.

        Returns:
            Keras tensor of per-pixel class scores after a softmax activation.
        """
        o = self.level_block(inputs, start_ch, depth, inc_rate, activation,
                             dropout, batchnorm, upconv, format_)
        o = BatchNormalization()(o)
        o = PReLU(shared_axes=[1, 2])(o)
        o = Conv2D(nb_classes, 1, padding='same', data_format=format_)(o)
        o = Activation('softmax')(o)
        return o

    def level_block(self, m, dim, depth, inc, acti, do, bn, up, format_="channels_last"):
        """Recursively build one encoder/decoder level of the U-Net.

        For ``depth > 0``:
            1. A residual encoder block at ``dim`` channels (kept as the skip
               connection).
            2. A stride-2 convolution to ``int(inc * dim)`` channels.
            3. The next level down.
            4. Upsampling back to ``dim`` channels.
            5. Concatenation with the skip connection.
            6. A residual decoder block.

        At ``depth == 0`` (the bottleneck), only the residual encoder block is
        applied.

        ``acti`` is passed to the residual blocks, which ignore it. ``do`` is
        only forwarded to deeper levels; the residual blocks always receive
        0.0 as their (unused) dropout. Presumably the spatial dimensions must
        be divisible by ``2 ** depth`` so the upsampled maps match the skip
        connections in ``concatenate`` -- TODO confirm.

        Returns:
            Keras tensor with ``dim`` channels.
        """
        if depth <= 0:
            # Bottleneck level.
            return self.res_block_enc(m, 0.0, dim, acti, bn, format_)

        skip = self.res_block_enc(m, 0.0, dim, acti, bn, format_)
        # Use a strided 2D conv for downsampling.
        m = Conv2D(int(inc * dim), 2, strides=2, padding='same', data_format=format_)(skip)
        # Pass format_ down so every level uses the same data format.
        m = self.level_block(m, int(inc * dim), depth - 1, inc, acti, do, bn, up, format_)
        if up:
            m = UpSampling2D(size=(2, 2), data_format=format_)(m)
            m = Conv2D(dim, 2, padding='same', data_format=format_)(m)
        else:
            m = Conv2DTranspose(dim, 3, strides=2, padding='same', data_format=format_)(m)
        merged = concatenate([skip, m])
        # The decoding path.
        return self.res_block_dec(merged, 0.0, dim, acti, bn, format_)

    def _residual_branch(self, m, dim, bn, format_):
        """Apply [BatchNorm if bn] -> PReLU -> 3x3 Conv2D(dim), twice, to m."""
        n = m
        for _ in range(2):
            if bn:
                n = BatchNormalization()(n)
            n = PReLU(shared_axes=[1, 2])(n)
            n = Conv2D(dim, 3, padding='same', data_format=format_)(n)
        return n

    def res_block_enc(self, m, drpout, dim, acti, bn, format_="channels_last"):
        """Encoding unit: a residual block computing ``m + branch(m)``.

        The branch is two rounds of (optional) BatchNormalization, PReLU and
        a 3x3 convolution to ``dim`` channels. The identity shortcut means
        ``m`` must already have ``dim`` channels. ``drpout`` and ``acti`` are
        accepted for interface compatibility but not used.
        """
        n = self._residual_branch(m, dim, bn, format_)
        return add([m, n])

    def res_block_dec(self, m, drpout, dim, acti, bn, format_="channels_last"):
        """Decoding unit: a residual block computing ``conv1x1(m) + branch(m)``.

        The branch is the same as in ``res_block_enc``. The shortcut is a
        bias-free 1x1 convolution that projects ``m`` to ``dim`` channels, so
        ``m`` may have any channel count (``level_block`` passes the
        concatenated skip and upsampled maps). ``drpout`` and ``acti`` are
        accepted for interface compatibility but not used.
        """
        n = self._residual_branch(m, dim, bn, format_)
        # Created after the branch layers to keep the layer creation order
        # unchanged (Keras auto-names layers in creation order).
        shortcut = Conv2D(dim, 1, padding='same', data_format=format_, use_bias=False)(m)
        return add([shortcut, n])