|
a |
|
b/model.py |
|
|
1 |
|
|
|
2 |
import numpy as np |
|
|
3 |
from keras.models import Model,load_model |
|
|
4 |
from keras.layers.advanced_activations import PReLU |
|
|
5 |
from keras.layers.convolutional import Conv2D, MaxPooling2D |
|
|
6 |
from keras.layers import Dropout,GaussianNoise, Input,Activation |
|
|
7 |
from keras.layers.normalization import BatchNormalization |
|
|
8 |
from keras.layers import Conv2DTranspose,UpSampling2D,concatenate,add |
|
|
9 |
from keras.optimizers import SGD |
|
|
10 |
import keras.backend as K |
|
|
11 |
from losses import * |
|
|
12 |
|
|
|
13 |
# Use the channels-last layout globally; the layers below also pass
# data_format explicitly and PReLU shared_axes=[1, 2] relies on it.
K.set_image_data_format(data_format="channels_last")
|
|
14 |
|
|
|
15 |
#u-net model |
|
|
16 |
class Unet_model(object):
    """U-Net style segmentation network built from residual encoder/decoder units.

    Creating an instance immediately builds and compiles the Keras model and
    stores it on ``self.model``.
    """

    def __init__(self, img_shape, load_model_weights=None):
        """Build and compile the network.

        Args:
            img_shape: shape of a single input sample (no batch axis), passed
                to ``keras.layers.Input``. Presumably (H, W, C) given the
                channels-last layout used throughout -- confirm with callers.
            load_model_weights: optional weights file passed to
                ``Model.load_weights`` after compilation (e.g. for prediction).
        """
        self.img_shape = img_shape
        self.load_model_weights = load_model_weights
        self.model = self.compile_unet()

    def compile_unet(self):
        """Create and compile the U-Net model.

        Returns:
            A compiled ``keras.models.Model`` using SGD, ``gen_dice_loss`` and
            the ``dice_whole_metric``/``dice_core_metric``/``dice_en_metric``
            metrics (all from the ``losses`` module).
        """
        i = Input(shape=self.img_shape)
        # Add gaussian noise to the input to combat overfitting.
        i_ = GaussianNoise(0.01)(i)
        # Project to 64 channels: unet() uses start_ch=64 and its first residual
        # block adds its input to its output, so the channel counts must match.
        i_ = Conv2D(64, 2, padding='same', data_format='channels_last')(i_)
        out = self.unet(inputs=i_)
        # Keras 2 spells these ``inputs``/``outputs`` (``input``/``output`` are
        # legacy Keras 1 names).
        model = Model(inputs=i, outputs=out)

        sgd = SGD(lr=0.08, momentum=0.9, decay=5e-6, nesterov=False)
        model.compile(loss=gen_dice_loss, optimizer=sgd,
                      metrics=[dice_whole_metric, dice_core_metric, dice_en_metric])
        # Load weights if set (e.g. for prediction).
        if self.load_model_weights is not None:
            model.load_weights(self.load_model_weights)
        return model

    def unet(self, inputs, nb_classes=4, start_ch=64, depth=3, inc_rate=2.,
             activation='relu', dropout=0.0, batchnorm=True, upconv=True,
             format_='channels_last'):
        """The U-Net architecture: recursive encoder/decoder plus softmax head.

        Args:
            inputs: input tensor; must already have ``start_ch`` channels
                because the first residual block adds its input to its output.
            nb_classes: number of output channels of the final 1x1 convolution.
            start_ch: number of filters at the top level.
            depth: number of downsampling levels.
            inc_rate: channel multiplier applied at each downsampling step.
            activation: forwarded to ``level_block`` but currently unused
                (PReLU is used everywhere).
            dropout: forwarded to ``level_block`` but currently unused.
            batchnorm: whether the residual blocks apply batch normalization.
            upconv: if true, upsample with UpSampling2D + Conv2D; otherwise
                with Conv2DTranspose.
            format_: Keras ``data_format`` for the convolutions. NOTE: the
                BatchNormalization/PReLU/concatenate/softmax defaults used here
                assume channels_last.

        Returns:
            Tensor of per-pixel softmax scores over ``nb_classes`` channels.
        """
        o = self.level_block(inputs, start_ch, depth, inc_rate, activation,
                             dropout, batchnorm, upconv, format_)
        o = BatchNormalization()(o)
        o = PReLU(shared_axes=[1, 2])(o)
        o = Conv2D(nb_classes, 1, padding='same', data_format=format_)(o)
        # Activation('softmax') normalizes over the last axis (channels_last).
        o = Activation('softmax')(o)
        return o

    def level_block(self, m, dim, depth, inc, acti, do, bn, up,
                    format_="channels_last"):
        """Recursively build one U-Net level and every level below it.

        Args:
            m: input tensor with ``dim`` channels.
            dim: number of filters at this level.
            depth: remaining downsampling levels; ``depth <= 0`` builds the
                bottleneck (a single encoding unit).
            inc: channel multiplier for the next (deeper) level.
            acti: activation name, currently unused (PReLU is used).
            do: dropout rate, currently unused.
            bn: whether the residual blocks use batch normalization.
            up: upsample with UpSampling2D + 2x2 Conv2D if true, otherwise with
                a stride-2 Conv2DTranspose.
            format_: Keras ``data_format`` for all convolutions at every level.

        Returns:
            Tensor with ``dim`` channels. The skip concatenation only lines up
            when the spatial dims are divisible by ``2 ** depth``.
        """
        if depth > 0:
            n = self.res_block_enc(m, 0.0, dim, acti, bn, format_)
            # Strided 2x2 convolution for downsampling (instead of pooling).
            m = Conv2D(int(inc * dim), 2, strides=2, padding='same',
                       data_format=format_)(n)
            # Propagate format_ so deeper levels use the same data format.
            m = self.level_block(m, int(inc * dim), depth - 1, inc, acti, do,
                                 bn, up, format_)
            if up:
                m = UpSampling2D(size=(2, 2), data_format=format_)(m)
                m = Conv2D(dim, 2, padding='same', data_format=format_)(m)
            else:
                m = Conv2DTranspose(dim, 3, strides=2, padding='same',
                                    data_format=format_)(m)
            # Skip connection: concatenate encoder and upsampled features along
            # the default last axis (the channel axis for channels_last).
            n = concatenate([n, m])
            # The decoding path.
            m = self.res_block_dec(n, 0.0, dim, acti, bn, format_)
        else:
            m = self.res_block_enc(m, 0.0, dim, acti, bn, format_)
        return m

    def _residual_branch(self, m, dim, bn, format_):
        """Residual branch: ([BatchNorm] -> PReLU -> 3x3 Conv2D(dim)) applied twice."""
        n = m
        for _ in range(2):
            # Fall back to the un-normalized tensor when batch norm is disabled.
            if bn:
                n = BatchNormalization()(n)
            # shared_axes=[1, 2] shares the slope over the spatial axes of a
            # channels_last 4-D tensor, i.e. one slope per channel.
            n = PReLU(shared_axes=[1, 2])(n)
            n = Conv2D(dim, 3, padding='same', data_format=format_)(n)
        return n

    def res_block_enc(self, m, drpout, dim, acti, bn, format_="channels_last"):
        """Encoding unit: pre-activation residual block with identity shortcut.

        Args:
            m: input tensor; must already have ``dim`` channels because it is
                added to the branch output.
            drpout: currently unused.
            dim: number of filters of both convolutions.
            acti: currently unused (PReLU is used instead).
            bn: whether to apply batch normalization before each activation.
            format_: Keras ``data_format`` for the convolutions.

        Returns:
            ``m + branch(m)``.
        """
        n = self._residual_branch(m, dim, bn, format_)
        return add([m, n])

    def res_block_dec(self, m, drpout, dim, acti, bn, format_="channels_last"):
        """Decoding unit: pre-activation residual block with 1x1 projection shortcut.

        Args:
            m: input tensor (the skip concatenation, so its channel count may
                differ from ``dim``).
            drpout: currently unused.
            dim: number of filters of the convolutions and of the output.
            acti: currently unused (PReLU is used instead).
            bn: whether to apply batch normalization before each activation.
            format_: Keras ``data_format`` for the convolutions.

        Returns:
            ``proj(m) + branch(m)`` with ``dim`` channels.
        """
        n = self._residual_branch(m, dim, bn, format_)
        # Bias-free 1x1 convolution projects m to dim channels for the addition.
        shortcut = Conv2D(dim, 1, padding='same', data_format=format_,
                          use_bias=False)(m)
        return add([shortcut, n])
|
|
119 |
|
|
|
120 |
|
|
|
121 |
|
|
|
122 |
|