# model.py
import numpy as np
from keras.models import Model, load_model
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Dropout, GaussianNoise, Input, Activation
from keras.layers.normalization import BatchNormalization
from keras.layers import Conv2DTranspose, UpSampling2D, concatenate, add
from keras.optimizers import SGD
import keras.backend as K

from losses import *
# Globally configure Keras for channels-last (NHWC) tensors; every layer
# below also passes a matching data_format argument explicitly.
K.set_image_data_format("channels_last")

# u-net model
class Unet_model(object):
    """Residual U-Net for multi-class image segmentation (4 output classes).

    The encoder and decoder are built from pre-activation residual blocks
    (BatchNorm -> PReLU -> Conv2D).  The network is compiled with SGD and a
    generalized dice loss, tracking per-region dice metrics
    (whole / core / enhancing — imported from `losses`).
    """

    def __init__(self, img_shape, load_model_weights=None):
        """
        Parameters
        ----------
        img_shape : tuple
            Shape of one input image, channels-last, e.g. (H, W, C)
            (matching K.set_image_data_format("channels_last")).
        load_model_weights : str or None
            Optional path to a weights file; loaded after compilation,
            typically when the model is built for prediction.
        """
        self.img_shape = img_shape
        self.load_model_weights = load_model_weights
        self.model = self.compile_unet()

    def compile_unet(self):
        """Build and compile the U-net; return the compiled Keras Model."""
        i = Input(shape=self.img_shape)
        # Gaussian noise on the input acts as a cheap regularizer against
        # overfitting.
        i_ = GaussianNoise(0.01)(i)
        # Lift the input to 64 channels so the first residual block's
        # identity shortcut (add) has matching channel counts.
        i_ = Conv2D(64, 2, padding='same', data_format='channels_last')(i_)
        out = self.unet(inputs=i_)
        # BUGFIX: Keras 2 uses the `inputs`/`outputs` keywords; the old
        # Keras-1 `input`/`output` spellings were deprecated and removed.
        model = Model(inputs=i, outputs=out)

        sgd = SGD(lr=0.08, momentum=0.9, decay=5e-6, nesterov=False)
        model.compile(loss=gen_dice_loss, optimizer=sgd,
                      metrics=[dice_whole_metric, dice_core_metric, dice_en_metric])
        # Load pre-trained weights if a path was given (prediction mode).
        if self.load_model_weights is not None:
            model.load_weights(self.load_model_weights)
        return model

    def unet(self, inputs, nb_classes=4, start_ch=64, depth=3, inc_rate=2.,
             activation='relu', dropout=0.0, batchnorm=True, upconv=True,
             format_='channels_last'):
        """The actual U-net architecture.

        Parameters
        ----------
        inputs     : input tensor (already projected to `start_ch` channels).
        nb_classes : number of segmentation classes in the softmax head.
        start_ch   : channels at the top level; multiplied by `inc_rate`
                     at each deeper level.
        depth      : number of down/up-sampling levels.
        inc_rate   : channel growth factor per level.
        activation : kept for interface compatibility; the blocks use PReLU.
        dropout    : kept for interface compatibility (unused by the blocks).
        batchnorm  : whether residual blocks apply BatchNormalization.
        upconv     : True -> UpSampling2D + Conv2D; False -> Conv2DTranspose.
        format_    : Keras data format string.
        """
        o = self.level_block(inputs, start_ch, depth, inc_rate, activation,
                             dropout, batchnorm, upconv, format_)
        o = BatchNormalization()(o)
        o = PReLU(shared_axes=[1, 2])(o)
        # 1x1 conv down to the class scores, then per-pixel softmax.
        o = Conv2D(nb_classes, 1, padding='same', data_format=format_)(o)
        o = Activation('softmax')(o)
        return o

    def level_block(self, m, dim, depth, inc, acti, do, bn, up, format_="channels_last"):
        """Recursively build one encoder/decoder level of the U-net.

        At each level: encode with a residual block, downsample with a
        strided conv, recurse one level deeper, upsample, concatenate the
        skip connection, and decode with a residual block.  `depth == 0`
        is the bottleneck (a single encoding block).
        """
        if depth > 0:
            n = self.res_block_enc(m, 0.0, dim, acti, bn, format_)
            # Strided 2D conv for downsampling (instead of max pooling).
            m = Conv2D(int(inc * dim), 2, strides=2, padding='same',
                       data_format=format_)(n)
            # BUGFIX: propagate `format_` into the recursive call; the
            # original omitted it, so deeper levels silently fell back to
            # the 'channels_last' default.
            m = self.level_block(m, int(inc * dim), depth - 1, inc, acti,
                                 do, bn, up, format_)
            if up:
                m = UpSampling2D(size=(2, 2), data_format=format_)(m)
                m = Conv2D(dim, 2, padding='same', data_format=format_)(m)
            else:
                m = Conv2DTranspose(dim, 3, strides=2, padding='same',
                                    data_format=format_)(m)
            # Skip connection: fuse encoder features with upsampled features.
            n = concatenate([n, m])
            # The decoding path.
            m = self.res_block_dec(n, 0.0, dim, acti, bn, format_)
        else:
            m = self.res_block_enc(m, 0.0, dim, acti, bn, format_)
        return m

    def res_block_enc(self, m, drpout, dim, acti, bn, format_="channels_last"):
        """Encoding residual block: two (BN -> PReLU -> Conv) stages with an
        identity shortcut.  `drpout` and `acti` are kept for interface
        compatibility (PReLU is used instead of `acti`).

        Assumes `m` already has `dim` channels so the identity add matches.
        """
        # BUGFIX: the original read `if bn else n`, which raised NameError
        # when bn=False because `n` was unbound on that branch.
        n = BatchNormalization()(m) if bn else m
        n = PReLU(shared_axes=[1, 2])(n)
        n = Conv2D(dim, 3, padding='same', data_format=format_)(n)

        n = BatchNormalization()(n) if bn else n
        n = PReLU(shared_axes=[1, 2])(n)
        n = Conv2D(dim, 3, padding='same', data_format=format_)(n)

        # Identity shortcut.
        n = add([m, n])
        return n

    def res_block_dec(self, m, drpout, dim, acti, bn, format_="channels_last"):
        """Decoding residual block: like `res_block_enc` but with a 1x1
        projection shortcut, because the concatenated skip input has 2*dim
        channels and an identity add would not match.
        """
        # BUGFIX: `else m` — the original `else n` was a NameError (see
        # res_block_enc).
        n = BatchNormalization()(m) if bn else m
        n = PReLU(shared_axes=[1, 2])(n)
        n = Conv2D(dim, 3, padding='same', data_format=format_)(n)

        n = BatchNormalization()(n) if bn else n
        n = PReLU(shared_axes=[1, 2])(n)
        n = Conv2D(dim, 3, padding='same', data_format=format_)(n)

        # 1x1 projection of the block input down to `dim` channels.
        Save = Conv2D(dim, 1, padding='same', data_format=format_,
                      use_bias=False)(m)
        n = add([Save, n])
        return n