Switch to unified view

a b/DigiPathAI/models/densenet.py
1
"""
2
    github cite: 
3
"""
4
5
from __future__ import absolute_import
6
from __future__ import division
7
from __future__ import print_function
8
9
from datetime import datetime
10
import os
11
import glob
12
import random
13
14
import imgaug
15
from imgaug import augmenters as iaa
16
from PIL import Image
17
from tqdm import tqdm
18
import matplotlib.pyplot as plt
19
20
21
import numpy as np 
22
import tensorflow as tf
23
from tensorflow.keras import backend as K
24
from tensorflow.keras.models import Model
25
from tensorflow.keras.layers import (Input, BatchNormalization, Conv2D, MaxPooling2D,                                                   AveragePooling2D, ZeroPadding2D, concatenate,   
26
                    Concatenate, UpSampling2D, Activation, Lambda)
27
from tensorflow.keras.losses import categorical_crossentropy
28
from tensorflow.keras.optimizers import Adam
29
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
30
from tensorflow.keras import metrics
31
32
33
# Densenet Model
34
bn_axis = 3
35
channel_axis = bn_axis
36
37
def conv_block(prev, num_filters, kernel=(3, 3), strides=(1, 1), act='relu', prefix=None):
38
    name = None
39
    if prefix is not None:
40
        name = prefix + '_conv'
41
    conv = Conv2D(num_filters, kernel, padding='same', kernel_initializer='he_normal', strides=strides, name=name)(prev)
42
    if prefix is not None:
43
        name = prefix + '_norm'
44
    conv = BatchNormalization(name=name, axis=bn_axis)(conv)
45
    if prefix is not None:
46
        name = prefix + '_act'
47
    conv = Activation(act, name=name)(conv)
48
    return conv
49
50
def dense_conv_block(x, growth_rate, name):
51
    """A building block for a dense block.
52
    # Arguments
53
        x: input tensor.
54
        growth_rate: float, growth rate at dense layers.
55
        name: string, block label.
56
    # Returns
57
        Output tensor for the block.
58
    """
59
    bn_axis = 3
60
    x1 = BatchNormalization(axis=bn_axis,
61
                                   epsilon=1.001e-5,
62
                                   name=name + '_0_bn')(x)
63
    x1 = Activation('relu', name=name + '_0_relu')(x1)
64
    x1 = Conv2D(4 * growth_rate, 1,
65
                       use_bias=False,
66
                       name=name + '_1_conv')(x1)
67
    x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
68
                                   name=name + '_1_bn')(x1)
69
    x1 = Activation('relu', name=name + '_1_relu')(x1)
70
    x1 = Conv2D(growth_rate, 3,
71
                       padding='same',
72
                       use_bias=False,
73
                       name=name + '_2_conv')(x1)
74
    x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
75
    return x
76
77
def dense_block(x, blocks, name):
78
    """A dense block.
79
    # Arguments
80
        x: input tensor.
81
        blocks: integer, the number of building blocks.
82
        name: string, block label.
83
    # Returns
84
        output tensor for the block.
85
    """
86
    for i in range(blocks):
87
        x = dense_conv_block(x, 32, name=name + '_block' + str(i + 1))
88
    return x
89
90
91
def transition_block(x, reduction, name):
92
    """A transition block.
93
    # Arguments
94
        x: input tensor.
95
        reduction: float, compression rate at transition layers.
96
        name: string, block label.
97
    # Returns
98
        output tensor for the block.
99
    """
100
    bn_axis = 3
101
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
102
                                  name=name + '_bn')(x)
103
    x = Activation('relu', name=name + '_relu')(x)
104
    x = Conv2D(int(K.int_shape(x)[bn_axis] * reduction), 1,
105
                      use_bias=False,
106
                      name=name + '_conv')(x)
107
    x = AveragePooling2D(2, strides=2, name=name + '_pool')(x)
108
    return x
109
110
def unet_densenet121(input_shape, weights='imagenet'):
111
    blocks = [6, 12, 24, 16]
112
    n_channel = 3
113
    n_class = 2
114
    img_input = Input(input_shape + (n_channel,))
115
    
116
    x = ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
117
    x = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
118
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
119
                           name='conv1/bn')(x)
120
    x = Activation('relu', name='conv1/relu')(x)
121
    conv1 = x
122
    x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
123
    x = MaxPooling2D(3, strides=2, name='pool1')(x)
124
    x = dense_block(x, blocks[0], name='conv2')
125
    conv2 = x
126
    x = transition_block(x, 0.5, name='pool2')
127
    x = dense_block(x, blocks[1], name='conv3')
128
    conv3 = x
129
    x = transition_block(x, 0.5, name='pool3')
130
    x = dense_block(x, blocks[2], name='conv4')
131
    conv4 = x
132
    x = transition_block(x, 0.5, name='pool4')
133
    x = dense_block(x, blocks[3], name='conv5')
134
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
135
                           name='bn')(x)
136
    conv5 = x 
137
    
138
    conv6 = conv_block(UpSampling2D()(conv5), 320)
139
    conv6 = concatenate([conv6, conv4], axis=-1)
140
    conv6 = conv_block(conv6, 320)
141
142
    conv7 = conv_block(UpSampling2D()(conv6), 256)
143
    conv7 = concatenate([conv7, conv3], axis=-1)
144
    conv7 = conv_block(conv7, 256)
145
146
    conv8 = conv_block(UpSampling2D()(conv7), 128)
147
    conv8 = concatenate([conv8, conv2], axis=-1)
148
    conv8 = conv_block(conv8, 128)
149
150
    conv9 = conv_block(UpSampling2D()(conv8), 96)
151
    conv9 = concatenate([conv9, conv1], axis=-1)
152
    conv9 = conv_block(conv9, 96)
153
154
    conv10 = conv_block(UpSampling2D()(conv9), 64)
155
    conv10 = conv_block(conv10, 64)
156
    res = Conv2D(n_class, (1, 1), activation='softmax')(conv10)
157
    model = Model(img_input, res)
158
159
    return model
160
#model = unet_densenet121(input_shape=(256,256), weights=None)
161
#model.summary()
162
163
164
# In[6]:
165