Diff of /unet.py [000000] .. [48d89d]

Switch to unified view

a b/unet.py
1
"""
2
This code is to build and train 2D U-Net
3
"""
4
import numpy as np
5
import sys
6
import subprocess
7
import argparse
8
import os
9
10
from keras.models import Model
11
from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, ZeroPadding2D, add
12
from keras.optimizers import Adam, SGD
13
from keras.callbacks import ModelCheckpoint, CSVLogger
14
from keras import backend as K
15
from keras import losses
16
17
import tensorflow as tf
18
import matplotlib.pyplot as plt
19
import pandas as pd
20
import csv
21
22
from utils import *
23
from data import load_train_data
24
25
K.set_image_data_format('channels_last')  # Tensorflow dimension ordering
26
27
# ----- paths setting -----
28
data_path = sys.argv[1] + "/"
29
model_path = data_path + "models/"
30
log_path = data_path + "logs/"
31
32
33
# ----- params for training and testing -----
34
batch_size = 1
35
cur_fold = sys.argv[2]
36
plane = sys.argv[3]
37
epoch = int(sys.argv[4])
38
init_lr = float(sys.argv[5])
39
40
41
# ----- Dice Coefficient and cost function for training -----
42
smooth = 1.
43
44
def dice_coef(y_true, y_pred):
45
    y_true_f = K.flatten(y_true)
46
    y_pred_f = K.flatten(y_pred)
47
    intersection = K.sum(y_true_f * y_pred_f)
48
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
49
50
def dice_coef_loss(y_true, y_pred):
51
    return  -dice_coef(y_true, y_pred)
52
53
54
def get_unet((img_rows, img_cols), flt=64, pool_size=(2, 2, 2), init_lr=1.0e-5):
55
    """build and compile Neural Network"""
56
57
    print "start building NN"
58
    inputs = Input((img_rows, img_cols, 1))
59
60
    conv1 = Conv2D(flt, (3, 3), activation='relu', padding='same')(inputs)
61
    conv1 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv1)
62
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
63
64
    conv2 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(pool1)
65
    conv2 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(conv2)
66
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
67
68
    conv3 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(pool2)
69
    conv3 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(conv3)
70
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
71
72
    conv4 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(pool3)
73
    conv4 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(conv4)
74
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
75
76
    conv5 = Conv2D(flt*16, (3, 3), activation='relu', padding='same')(pool4)
77
    conv5 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(conv5)
78
79
    up6 = concatenate([Conv2DTranspose(flt*8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
80
    conv6 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(up6)
81
    conv6 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(conv6)
82
83
    up7 = concatenate([Conv2DTranspose(flt*4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
84
    conv7 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(up7)
85
    conv7 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(conv7)
86
87
    up8 = concatenate([Conv2DTranspose(flt*2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
88
    conv8 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(up8)
89
    conv8 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv8)
90
91
    up9 = concatenate([Conv2DTranspose(flt, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
92
    conv9 = Conv2D(flt, (3, 3), activation='relu', padding='same')(up9)
93
    conv9 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv9)
94
95
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
96
97
    model = Model(inputs=[inputs], outputs=[conv10])
98
99
    model.compile(optimizer=Adam(lr=init_lr), loss=dice_coef_loss, metrics=[dice_coef])
100
101
    return model
102
103
104
def train(fold, plane, batch_size, nb_epoch,init_lr):
105
    """
106
    train an Unet model with data from load_train_data()
107
108
    Parameters
109
    ----------
110
    fold : string
111
        which fold is experimenting in 4-fold. It should be one of 0/1/2/3
112
113
    plane : char
114
        which plane is experimenting. It is from 'X'/'Y'/'Z'
115
116
    batch_size : int
117
        size of mini-batch
118
119
    nb_epoch : int
120
        number of epochs to train NN
121
122
    init_lr : float
123
        initial learning rate
124
    """
125
126
    print "number of epoch: ", nb_epoch
127
    print "learning rate: ", init_lr
128
129
    # --------------------- load and preprocess training data -----------------
130
    print '-'*80
131
    print '         Loading and preprocessing train data...'
132
    print '-'*80
133
134
    imgs_train, imgs_mask_train = load_train_data(fold, plane)
135
136
    imgs_row = imgs_train.shape[1]
137
    imgs_col = imgs_train.shape[2]
138
139
    imgs_train = preprocess(imgs_train)
140
    imgs_mask_train = preprocess(imgs_mask_train)
141
142
    imgs_train = imgs_train.astype('float32')
143
    imgs_mask_train = imgs_mask_train.astype('float32')
144
145
    # ---------------------- Create, compile, and train model ------------------------
146
    print '-'*80
147
    print '     Creating and compiling model...'
148
    print '-'*80
149
150
    model = get_unet((imgs_row, imgs_col), pool_size=(2, 2, 2), init_lr=init_lr)
151
    print model.summary()
152
153
    print '-'*80
154
    print '     Fitting model...'
155
    print '-'*80
156
157
    ver = 'unet_fd%s_%s_ep%s_lr%s.csv'%(cur_fold, plane, epoch, init_lr)
158
    csv_logger = CSVLogger(log_path + ver)
159
    model_checkpoint = ModelCheckpoint(model_path + ver + ".h5",
160
                                       monitor='loss',
161
                                       save_best_only=False,
162
                                       period=10)
163
164
    history = model.fit(imgs_train, imgs_mask_train,
165
                        batch_size= batch_size, epochs= nb_epoch, verbose=1, shuffle=True,
166
                        callbacks=[model_checkpoint, csv_logger])
167
168
169
if __name__ == "__main__":
170
171
    train(cur_fold, plane, batch_size, epoch, init_lr)
172
173
    print "training done"