# train.py

import numpy as np
import tf_models
from sklearn.preprocessing import scale
import tensorflow as tf
from tensorflow.contrib.keras.python.keras.backend import learning_phase
from tensorflow.contrib.keras.python.keras.layers import concatenate, Conv3D
from nibabel import load as load_nii
import os
import argparse
import keras


def parse_inputs():
    parser = argparse.ArgumentParser(description='train the model')
    parser.add_argument('-r', '--root-path', dest='root_path', default='/media/lele/Data/spie/Brats17TrainingData/HGG')
    parser.add_argument('-sp', '--save-path', dest='save_path', default='dense24_correction')
    parser.add_argument('-lp', '--load-path', dest='load_path', default='dense24_correction')
    parser.add_argument('-ow', '--offset-width', dest='offset_w', type=int, default=12)
    parser.add_argument('-oh', '--offset-height', dest='offset_h', type=int, default=12)
    parser.add_argument('-oc', '--offset-channel', dest='offset_c', nargs='+', type=int, default=12)
    parser.add_argument('-ws', '--width-size', dest='wsize', type=int, default=38)
    parser.add_argument('-hs', '--height-size', dest='hsize', type=int, default=38)
    parser.add_argument('-cs', '--channel-size', dest='csize', type=int, default=38)
    parser.add_argument('-ps', '--pred-size', dest='psize', type=int, default=12)
    parser.add_argument('-bs', '--batch-size', dest='batch_size', type=int, default=2)
    parser.add_argument('-e', '--num-epochs', dest='num_epochs', type=int, default=5)
    # store_true avoids the argparse type=bool pitfall, where any non-empty
    # string (including 'False') parses as True.
    parser.add_argument('-c', '--continue-training', dest='continue_training', action='store_true', default=False)
    parser.add_argument('-mn', '--model_name', dest='model_name', type=str, default='dense24')
    parser.add_argument('-nc', '--n4correction', dest='correction', action='store_true', default=False)
    parser.add_argument('-gpu', '--gpu_id', dest='gpu_id', type=str, default='0')
    return vars(parser.parse_args())


options = parse_inputs()

os.environ["CUDA_VISIBLE_DEVICES"] = options['gpu_id']


def acc_tf(y_pred, y_true):
    correct_prediction = tf.equal(tf.cast(tf.argmax(y_pred, -1), tf.int32), tf.cast(tf.argmax(y_true, -1), tf.int32))
    return 100 * tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def get_patches_3d(data, labels, centers, hsize, wsize, csize, psize, preprocess=True):
    """
    Extract input patches and their (smaller, centered) label patches.

    :param data: 4D nparray (h, w, c, modalities)
    :param labels: 3D nparray of integer segmentation labels
    :param centers: 3 x N array of patch-center coordinates
    :param hsize: patch height
    :param wsize: patch width
    :param csize: patch depth (number of slices)
    :param psize: side length of the predicted (label) patch
    :return: (patches_x, patches_y) as numpy arrays
    """
    patches_x, patches_y = [], []
    offset_p = (hsize - psize) // 2
    for i in range(len(centers[0])):
        h, w, c = centers[0, i], centers[1, i], centers[2, i]
        # Clamp each patch so it stays inside the 240 x 240 x 155 BraTS volume.
        h_beg = min(max(0, h - hsize // 2), 240 - hsize)
        w_beg = min(max(0, w - wsize // 2), 240 - wsize)
        c_beg = min(max(0, c - csize // 2), 155 - csize)
        ph_beg = h_beg + offset_p
        pw_beg = w_beg + offset_p
        pc_beg = c_beg + offset_p
        vox = data[h_beg:h_beg + hsize, w_beg:w_beg + wsize, c_beg:c_beg + csize, :]
        vox_labels = labels[ph_beg:ph_beg + psize, pw_beg:pw_beg + psize, pc_beg:pc_beg + psize]
        patches_x.append(vox)
        patches_y.append(vox_labels)
    return np.array(patches_x), np.array(patches_y)
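
# Patch geometry with the default sizes (hsize = wsize = csize = 38, psize = 12):
# the label patch is the central 12^3 block of the 38^3 input patch, starting at
# offset (38 - 12) // 2 = 13, which matches the 13:25 window applied to the
# network scores in train() below. Minimal illustrative call (the arrays come
# from vox_generator and are not created here):
#   x, y = get_patches_3d(data_norm, labels, centers[:, :2], 38, 38, 38, 12, False)
#   x.shape  # (2, 38, 38, 38, 4)
#   y.shape  # (2, 12, 12, 12)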


def positive_ratio(x):
    return float(np.sum(np.greater(x, 0))) / np.prod(x.shape)


def norm(image):
    # Standardize intensities using only the non-zero (brain) voxels.
    image = np.squeeze(image)
    image_nonzero = image[np.nonzero(image)]
    return (image - image_nonzero.mean()) / image_nonzero.std()


def segmentation_loss(y_true, y_pred, n_classes):
    # Cast the one-hot ground truth to float32: softmax_cross_entropy_with_logits
    # expects labels with the same float dtype as the logits.
    y_true = tf.cast(tf.reshape(y_true, (-1, n_classes)), tf.float32)
    y_pred = tf.reshape(y_pred, (-1, n_classes))
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
                                                                  logits=y_pred))


def vox_preprocess(vox):
    vox_shape = vox.shape
    vox = np.reshape(vox, (-1, vox_shape[-1]))
    vox = scale(vox, axis=0)
    return np.reshape(vox, vox_shape)


def one_hot(y, num_classes):
    y_ = np.zeros([len(y), num_classes])
    y_[np.arange(len(y)), y] = 1
    return y_


def dice_coef_np(y_true, y_pred, num_classes):
    """
    Per-class Dice coefficient computed from sparse (integer) label maps.

    :param y_true: sparse labels
    :param y_pred: sparse labels
    :param num_classes: number of classes
    :return: array with one Dice score per class
    """
    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)
    y_true = y_true.flatten()
    y_true = one_hot(y_true, num_classes)
    y_pred = y_pred.flatten()
    y_pred = one_hot(y_pred, num_classes)
    intersection = np.sum(y_true * y_pred, axis=0)
    return (2. * intersection) / (np.sum(y_true, axis=0) + np.sum(y_pred, axis=0))
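
# Quick sanity check for dice_coef_np on hypothetical toy labels (not part of
# the training pipeline): two 4-voxel maps that agree on three voxels.
#   dice_coef_np(np.array([0, 1, 1, 2]), np.array([0, 1, 2, 2]), num_classes=3)
#   # -> array([1.        , 0.66666667, 0.66666667])  per-class Dice for 0, 1, 2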


def vox_generator(all_files, n_pos, n_neg, correction=False):
    path = options['root_path']
    while 1:
        for file in all_files:
            if correction:
                flair = load_nii(os.path.join(path, file, file + '_flair_corrected.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, file, file + '_t2_corrected.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, file, file + '_t1_corrected.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, file, file + '_t1ce_corrected.nii.gz')).get_data()
            else:
                flair = load_nii(os.path.join(path, file, file + '_flair.nii.gz')).get_data()
                t2 = load_nii(os.path.join(path, file, file + '_t2.nii.gz')).get_data()
                t1 = load_nii(os.path.join(path, file, file + '_t1.nii.gz')).get_data()
                t1ce = load_nii(os.path.join(path, file, file + '_t1ce.nii.gz')).get_data()

            data_norm = np.array([norm(flair), norm(t2), norm(t1), norm(t1ce)])
            data_norm = np.transpose(data_norm, axes=[1, 2, 3, 0])
            labels = load_nii(os.path.join(path, file, file + '_seg.nii.gz')).get_data()

            # Sample patch centers: up to n_pos from tumor voxels and up to n_neg
            # from healthy brain voxels (non-zero flair, zero label).
            foreground = np.array(np.where(labels > 0))
            background = np.array(np.where((labels == 0) & (flair > 0)))

            # n_pos = int(foreground.shape[1] * discount)
            foreground = foreground[:, np.random.permutation(foreground.shape[1])[:n_pos]]
            background = background[:, np.random.permutation(background.shape[1])[:n_neg]]

            centers = np.concatenate((foreground, background), axis=1)
            # Shuffle over the actual number of sampled centers, which can be
            # smaller than n_pos + n_neg for patients with few tumor voxels.
            centers = centers[:, np.random.permutation(centers.shape[1])]

            yield data_norm, labels, centers
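
# Sketch of how the generator is consumed (the folder name is illustrative; the
# real list comes from train.txt in train() below):
#   gen = vox_generator(all_files=['Brats17_TCIA_001_1'], n_pos=200, n_neg=200)
#   data_norm, labels, centers = next(gen)
#   # data_norm: (240, 240, 155, 4) normalized [flair, t2, t1, t1ce] channels
#   # labels:    (240, 240, 155) integer segmentation map
#   # centers:   (3, ~400) shuffled patch-center coordinates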


def label_transform(y, nlabels):
    return [
        keras.utils.to_categorical(np.copy(y).astype(dtype=bool),
                                   num_classes=2).reshape([y.shape[0], y.shape[1], y.shape[2], y.shape[3], 2]),

        keras.utils.to_categorical(y,
                                   num_classes=nlabels).reshape([y.shape[0], y.shape[1], y.shape[2], y.shape[3], nlabels])
    ]
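
# label_transform builds the two training targets used in train(): a binary
# whole-tumor mask for the flair/t2 branch and the full nlabels-class mask for
# the t1/t1ce branch. Shapes for an illustrative label batch y of shape
# (2, 12, 12, 12) with values in {0, 1, 2, 4}:
#   binary, multi = label_transform(y, 5)
#   binary.shape  # (2, 12, 12, 12, 2)
#   multi.shape   # (2, 12, 12, 12, 5)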


def train():
    NUM_EPOCHS = options['num_epochs']
    LOAD_PATH = options['load_path']
    SAVE_PATH = options['save_path']
    PSIZE = options['psize']
    HSIZE = options['hsize']
    WSIZE = options['wsize']
    CSIZE = options['csize']
    model_name = options['model_name']
    BATCH_SIZE = options['batch_size']
    continue_training = options['continue_training']

    files = []
    num_labels = 5
    with open('train.txt') as f:
        for line in f:
            files.append(line.strip())
    print('%d training samples' % len(files))

    flair_t2_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    t1_t1ce_node = tf.placeholder(dtype=tf.float32, shape=(None, HSIZE, WSIZE, CSIZE, 2))
    flair_t2_gt_node = tf.placeholder(dtype=tf.int32, shape=(None, PSIZE, PSIZE, PSIZE, 2))
    t1_t1ce_gt_node = tf.placeholder(dtype=tf.int32, shape=(None, PSIZE, PSIZE, PSIZE, 5))

    if model_name == 'dense48':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat_large(input=t1_t1ce_node, name='t1')
    elif model_name == 'no_dense':
        flair_t2_15, flair_t2_27 = tf_models.PlainCounterpart(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.PlainCounterpart(input=t1_t1ce_node, name='t1')
    elif model_name == 'dense24':
        flair_t2_15, flair_t2_27 = tf_models.BraTS2ScaleDenseNetConcat(input=flair_t2_node, name='flair')
        t1_t1ce_15, t1_t1ce_27 = tf_models.BraTS2ScaleDenseNetConcat(input=t1_t1ce_node, name='t1')
    else:
        raise ValueError('No such model name: %s' % model_name)

    # The t1/t1ce branch also receives the flair/t2 features before classification.
    t1_t1ce_15 = concatenate([t1_t1ce_15, flair_t2_15])
    t1_t1ce_27 = concatenate([t1_t1ce_27, flair_t2_27])

    flair_t2_15 = Conv3D(2, kernel_size=1, strides=1, padding='same', name='flair_t2_15_cls')(flair_t2_15)
    flair_t2_27 = Conv3D(2, kernel_size=1, strides=1, padding='same', name='flair_t2_27_cls')(flair_t2_27)
    t1_t1ce_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_15_cls')(t1_t1ce_15)
    t1_t1ce_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same', name='t1_t1ce_27_cls')(t1_t1ce_27)

    # Sum the two scales and keep only the central 12^3 region covered by the labels.
    flair_t2_score = flair_t2_15[:, 13:25, 13:25, 13:25, :] + \
                     flair_t2_27[:, 13:25, 13:25, 13:25, :]

    t1_t1ce_score = t1_t1ce_15[:, 13:25, 13:25, 13:25, :] + \
                    t1_t1ce_27[:, 13:25, 13:25, 13:25, :]

    loss = segmentation_loss(flair_t2_gt_node, flair_t2_score, 2) + \
        segmentation_loss(t1_t1ce_gt_node, t1_t1ce_score, 5)

    acc_flair_t2 = acc_tf(y_pred=flair_t2_score, y_true=flair_t2_gt_node)
    acc_t1_t1ce = acc_tf(y_pred=t1_t1ce_score, y_true=t1_t1ce_gt_node)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=5e-4).minimize(loss)

    saver = tf.train.Saver(max_to_keep=15)
    data_gen_train = vox_generator(all_files=files, n_pos=200, n_neg=200, correction=options['correction'])

    with tf.Session() as sess:
        if continue_training:
            saver.restore(sess, LOAD_PATH)
        else:
            sess.run(tf.global_variables_initializer())
        for ei in range(NUM_EPOCHS):
            for pi in range(len(files)):
                acc_pi, loss_pi = [], []
                data, labels, centers = next(data_gen_train)
                n_batches = int(np.ceil(float(centers.shape[1]) / BATCH_SIZE))
                for nb in range(n_batches):
                    offset_batch = min(nb * BATCH_SIZE, centers.shape[1] - BATCH_SIZE)
                    data_batch, label_batch = get_patches_3d(data, labels, centers[:, offset_batch:offset_batch + BATCH_SIZE], HSIZE, WSIZE, CSIZE, PSIZE, False)
                    label_batch = label_transform(label_batch, 5)
                    _, l, acc_ft, acc_t1c = sess.run(fetches=[optimizer, loss, acc_flair_t2, acc_t1_t1ce],
                                                     feed_dict={flair_t2_node: data_batch[:, :, :, :, :2],
                                                                t1_t1ce_node: data_batch[:, :, :, :, 2:],
                                                                flair_t2_gt_node: label_batch[0],
                                                                t1_t1ce_gt_node: label_batch[1],
                                                                learning_phase(): 1})
                    acc_pi.append([acc_ft, acc_t1c])
                    loss_pi.append(l)
                    n_pos_sum = np.sum(np.reshape(label_batch[0], (-1, 2)), axis=0)
                    print('epoch-patient: %d, %d, iter: %d-%d, p%%: %.4f, loss: %.4f, acc_flair_t2: %.2f%%, acc_t1_t1ce: %.2f%%' %
                          (ei + 1, pi + 1, nb + 1, n_batches, n_pos_sum[1] / float(np.sum(n_pos_sum)), l, acc_ft, acc_t1c))

                print('patient loss: %.4f, patient acc: %.4f' % (np.mean(loss_pi), np.mean(acc_pi)))

            saver.save(sess, SAVE_PATH, global_step=ei)
            print('model saved')


if __name__ == '__main__':
    train()
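
# Example invocation (paths are illustrative; point -r at your local BraTS
# training folder). Passing -nc loads the *_corrected.nii.gz (N4 bias-field
# corrected) volumes instead of the raw scans:
#   python train.py -r /path/to/Brats17TrainingData/HGG -mn dense24 -bs 2 -e 5 -gpu 0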