# train.py

import numpy as np
import tf_models
from sklearn.preprocessing import scale
import tensorflow as tf
from tensorflow.contrib.keras.python.keras.backend import learning_phase
from tensorflow.contrib.keras.python.keras.layers import concatenate, Conv3D
from nibabel import load as load_nii
import os
import argparse
import keras


def parse_inputs():
    parser = argparse.ArgumentParser(description='train the model')
    parser.add_argument('-r', '--root-path', dest='root_path', default='/media/lele/Data/spie/Brats17TrainingData/HGG')
    parser.add_argument('-sp', '--save-path', dest='save_path', default='dense24_correction')
    parser.add_argument('-lp', '--load-path', dest='load_path', default='dense24_correction')
    parser.add_argument('-ow', '--offset-width', dest='offset_w', type=int, default=12)
    parser.add_argument('-oh', '--offset-height', dest='offset_h', type=int, default=12)
    parser.add_argument('-oc', '--offset-channel', dest='offset_c', nargs='+', type=int, default=12)
    parser.add_argument('-ws', '--width-size', dest='wsize', type=int, default=38)
    parser.add_argument('-hs', '--height-size', dest='hsize', type=int, default=38)
    parser.add_argument('-cs', '--channel-size', dest='csize', type=int, default=38)
    parser.add_argument('-ps', '--pred-size', dest='psize', type=int, default=12)
    parser.add_argument('-bs', '--batch-size', dest='batch_size', type=int, default=2)
    parser.add_argument('-e', '--num-epochs', dest='num_epochs', type=int, default=5)
    # argparse's type=bool treats any non-empty string as True, so the boolean
    # options are exposed as flags instead.
    parser.add_argument('-c', '--continue-training', dest='continue_training', action='store_true')
    parser.add_argument('-mn', '--model_name', dest='model_name', type=str, default='dense24')
    parser.add_argument('-nc', '--n4correction', dest='correction', action='store_true')
    parser.add_argument('-gpu', '--gpu_id', dest='gpu_id', type=str, default='0')
    return vars(parser.parse_args())


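# Example invocation (path and values are illustrative only, not prescribed by
# this script):
#   python train.py -r /path/to/Brats17TrainingData/HGG -mn dense24 -bs 2 -e 5 -gpu 0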
options = parse_inputs()

os.environ["CUDA_VISIBLE_DEVICES"] = options['gpu_id']


def acc_tf(y_pred, y_true):
    # Voxel-wise accuracy (in percent) between one-hot predictions and labels.
    correct_prediction = tf.equal(tf.cast(tf.argmax(y_pred, -1), tf.int32), tf.cast(tf.argmax(y_true, -1), tf.int32))
    return 100 * tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def get_patches_3d(data, labels, centers, hsize, wsize, csize, psize, preprocess=True):
    """Extract 3D input patches and their central label patches around the given centers.

    :param data: 4D nparray (h, w, c, modalities)
    :param labels: 3D nparray of voxel-wise labels
    :param centers: 3 x N array of patch center coordinates
    :param hsize: patch height
    :param wsize: patch width
    :param csize: patch depth (number of slices)
    :param psize: size of the central label patch to predict
    :param preprocess: unused here; normalization is done elsewhere
    :return: (patches_x, patches_y) as nparrays
    """
    patches_x, patches_y = [], []
    offset_p = (hsize - psize) // 2  # integer offset of the label patch inside the input patch
    for i in range(len(centers[0])):
        h, w, c = centers[0, i], centers[1, i], centers[2, i]
        # Clamp the patch so it stays inside the 240 x 240 x 155 BraTS volume.
        h_beg = min(max(0, h - hsize // 2), 240 - hsize)
        w_beg = min(max(0, w - wsize // 2), 240 - wsize)
        c_beg = min(max(0, c - csize // 2), 155 - csize)
        ph_beg = h_beg + offset_p
        pw_beg = w_beg + offset_p
        pc_beg = c_beg + offset_p
        vox = data[h_beg:h_beg + hsize, w_beg:w_beg + wsize, c_beg:c_beg + csize, :]
        vox_labels = labels[ph_beg:ph_beg + psize, pw_beg:pw_beg + psize, pc_beg:pc_beg + psize]
        patches_x.append(vox)
        patches_y.append(vox_labels)
    return np.array(patches_x), np.array(patches_y)


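# Note on get_patches_3d (assuming the default sizes above): with
# hsize = wsize = csize = 38 and psize = 12, the label patch is the central
# 12**3 block of the 38**3 input patch, shifted by (38 - 12) // 2 = 13 voxels
# along each axis.

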
def positive_ratio(x):
    # Fraction of strictly positive entries in x.
    return float(np.sum(np.greater(x, 0))) / np.prod(x.shape)


def norm(image):
    # Normalize by the mean and std of the non-zero (brain) voxels only.
    image = np.squeeze(image)
    image_nonzero = image[np.nonzero(image)]
    return (image - image_nonzero.mean()) / image_nonzero.std()


def segmentation_loss(y_true, y_pred, n_classes):
    # y_pred holds raw logits and y_true one-hot labels; both are flattened to (voxels, classes).
    y_true = tf.reshape(y_true, (-1, n_classes))
    y_pred = tf.reshape(y_pred, (-1, n_classes))
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
                                                                  logits=y_pred))


def vox_preprocess(vox):
    # Standardize each modality/channel to zero mean and unit variance.
    vox_shape = vox.shape
    vox = np.reshape(vox, (-1, vox_shape[-1]))
    vox = scale(vox, axis=0)
    return np.reshape(vox, vox_shape)


def one_hot(y, num_classes):
    # y is expected to hold integer class indices in [0, num_classes).
    y_ = np.zeros([len(y), num_classes])
    y_[np.arange(len(y)), y] = 1
    return y_


def dice_coef_np(y_true, y_pred, num_classes):
    """Per-class Dice coefficient computed with numpy.

    :param y_true: sparse (integer) ground-truth labels
    :param y_pred: sparse (integer) predicted labels
    :param num_classes: number of classes
    :return: array of length num_classes with the Dice score of each class
    """
    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)
    y_true = y_true.flatten()
    y_true = one_hot(y_true, num_classes)
    y_pred = y_pred.flatten()
    y_pred = one_hot(y_pred, num_classes)
    intersection = np.sum(y_true * y_pred, axis=0)
    return (2. * intersection) / (np.sum(y_true, axis=0) + np.sum(y_pred, axis=0))


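# Illustrative example for dice_coef_np (toy values, not from the BraTS data):
#   y_true = np.array([0, 1, 1, 2]); y_pred = np.array([0, 1, 2, 2])
#   dice_coef_np(y_true, y_pred, 3)  ->  [1.0, 0.667, 0.667]
# A class absent from both arrays yields 0/0, i.e. nan.

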
def vox_generator(all_files, n_pos, n_neg, correction=False):
    """Endlessly yield (normalized 4-modality volume, label volume, sampled patch centers) per patient."""
    path = options['root_path']
    while True:
        for file in all_files:
            if correction:
                # N4 bias-field corrected inputs.
                flair = load_nii(os.path.join(path, file, file + '_flair_corrected.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, file, file + '_t2_corrected.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, file, file + '_t1_corrected.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, file, file + '_t1ce_corrected.nii.gz')).get_data()
            else:
                flair = load_nii(os.path.join(path, file, file + '_flair.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, file, file + '_t2.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, file, file + '_t1.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, file, file + '_t1ce.nii.gz')).get_data()

            data_norm = np.array([norm(flair), norm(t2), norm(t1), norm(t1ce)])
            data_norm = np.transpose(data_norm, axes=[1, 2, 3, 0])
            labels = load_nii(os.path.join(path, file, file + '_seg.nii.gz')).get_data()

            # Sample patch centers: n_pos from tumor voxels, n_neg from healthy brain voxels.
            foreground = np.array(np.where(labels > 0))
            background = np.array(np.where((labels == 0) & (flair > 0)))

            # n_pos = int(foreground.shape[1] * discount)
            foreground = foreground[:, np.random.permutation(foreground.shape[1])[:n_pos]]
            background = background[:, np.random.permutation(background.shape[1])[:n_neg]]

            centers = np.concatenate((foreground, background), axis=1)
            centers = centers[:, np.random.permutation(n_neg + n_pos)]

            yield data_norm, labels, centers


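# Expected data layout (inferred from the paths built above): under root_path,
# one directory per patient containing <patient>_flair.nii.gz, <patient>_t2.nii.gz,
# <patient>_t1.nii.gz, <patient>_t1ce.nii.gz and <patient>_seg.nii.gz, plus the
# *_corrected.nii.gz variants when -nc is used.

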
def label_transform(y, nlabels):
    # Returns two one-hot targets: a binary (tumor vs. background) map and the full nlabels-class map.
    return [
            keras.utils.to_categorical(np.copy(y).astype(dtype=np.bool),
                                       num_classes=2).reshape([y.shape[0], y.shape[1], y.shape[2], y.shape[3], 2]),
            keras.utils.to_categorical(y,
                                       num_classes=nlabels).reshape([y.shape[0], y.shape[1], y.shape[2], y.shape[3], nlabels])
            ]


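# Illustrative shapes for label_transform: with y of shape (2, 12, 12, 12) and
# nlabels=5 it returns one-hot arrays of shapes (2, 12, 12, 12, 2) and (2, 12, 12, 12, 5).

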
def train():
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']

    # train.txt lists one patient directory name per line, relative to root_path.
    files = []
    num_labels = 5
    with open('train.txt') as f:
        for line in f:
            files.append(line.strip())
    print '%d training samples' % len(files)

    # Two-branch inputs: FLAIR/T2 channels and T1/T1ce channels, each with their own ground truth.
    flair_t2_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32, shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32, shape=(None, PSIZE, PSIZE, PSIZE, 5))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node, name='t1')
    elif model_name == 'dense24':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
    else:
        raise ValueError('No such model name: %s' % model_name)

    # The T1/T1ce branch also sees the FLAIR/T2 features at both scales.
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    # 1x1x1 convolutions produce per-voxel class scores: 2 classes for the
    # FLAIR/T2 branch (whole tumor), num_labels classes for the T1/T1ce branch.
    flair_t2_15 = Conv3D(2, kernel_size=1, strides=1, padding='same', name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2, kernel_size=1, strides=1, padding='same', name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum the two scale outputs over the 12-voxel core (indices 13:25) that matches PSIZE.
    flair_t2_score = flair_t2_15[:, 13:25, 13:25, 13:25, :] + \
                     flair_t2_27[:, 13:25, 13:25, 13:25, :]

    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    loss = segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) + \
           segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5)

    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=5e-4).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200, correction=options['correction'])

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())
        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                data, labels, centers = data_gen_train.next()
                n_batches = int(np.ceil(float(centers.shape[1]) / BATCH_SIZE))
                for nb in range(n_batches):
                    offset_batch = min(nb * BATCH_SIZE, centers.shape[1] - BATCH_SIZE)
                    data_batch, label_batch = get_patches_3d(data, labels, centers[:, offset_batch:offset_batch + BATCH_SIZE], HSIZE, WSIZE, CSIZE, PSIZE, False)
                    label_batch = label_transform(label_batch, 5)
                    _, l, acc_ft, acc_t1c = sess.run(fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                                                     feed_dict={flair_t2_node: data_batch[:, :, :, :, :2],
                                                                t1_t1ce_node: data_batch[:, :, :, :, 2:],
                                                                flair_t2_gt_node: label_batch[0],
                                                                t1_t1ce_gt_node: label_batch[1],
                                                                learning_phase(): 1})
                    acc_pi.append([acc_ft, acc_t1c])
                    loss_pi.append(l)
                    n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)), axis=0)
                    print 'epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' % \
                          (ei + 1, pi + 1, nb + 1, n_batches, n_pos_sum[1] / float(np.sum(n_pos_sum)), l, acc_ft, acc_t1c)

                print 'patient loss: %.4f, patient acc: %.4f' % (np.mean(loss_pi), np.mean(acc_pi))

            saver.save(sess, SAVE_PATH, global_step=ei)
            print 'model saved'


if __name__ == '__main__':
    train()