|
a |
|
b/train.py |
|
|
1 |
import numpy as np |
|
|
2 |
import tensorflow as tf |
|
|
3 |
from tensorflow.python.platform import flags |
|
|
4 |
from data_generator import ImageDataGenerator |
|
|
5 |
import logging |
|
|
6 |
from utils import _eval_dice, _connectivity_region_analysis, parse_fn, _crop_object_region, _get_coutour_sample, parse_fn_haus,_eval_haus |
|
|
7 |
import time |
|
|
8 |
import os |
|
|
9 |
import SimpleITK as sitk |
|
|
10 |
|
|
|
11 |
def train(model, saver, sess, train_file_list, test_file, args, resume_itr=0): |
|
|
12 |
|
|
|
13 |
if args.log: |
|
|
14 |
train_writer = tf.summary.FileWriter(args.log_dir + '/' + args.phase + '/', sess.graph) |
|
|
15 |
|
|
|
16 |
# Data loaders |
|
|
17 |
with tf.device('/cpu:0'): |
|
|
18 |
tr_data_list, train_iterator_list, train_next_list = [],[],[] |
|
|
19 |
for i in range(len(train_file_list)): |
|
|
20 |
tr_data = ImageDataGenerator(train_file_list[i], mode='training', \ |
|
|
21 |
batch_size=args.meta_batch_size, num_classes=args.n_class, shuffle=True) |
|
|
22 |
tr_data_list.append(tr_data) |
|
|
23 |
train_iterator_list.append(tf.data.Iterator.from_structure(tr_data.data.output_types,tr_data.data.output_shapes)) |
|
|
24 |
train_next_list.append(train_iterator_list[i].get_next()) |
|
|
25 |
|
|
|
26 |
# Ops for initializing different iterators |
|
|
27 |
training_init_op = [] |
|
|
28 |
train_batches_per_epoch = [] |
|
|
29 |
for i in range(len(train_file_list)): |
|
|
30 |
training_init_op.append(train_iterator_list[i].make_initializer(tr_data_list[i].data)) |
|
|
31 |
sess.run(training_init_op[i]) # initialize training sample generator at itr=0 |
|
|
32 |
|
|
|
33 |
# Training begins |
|
|
34 |
best_test_dice = 0 |
|
|
35 |
best_test_haus = 0 |
|
|
36 |
for epoch in xrange(0, args.epoch): |
|
|
37 |
for itr in range(resume_itr, args.train_iterations): |
|
|
38 |
start = time.time() |
|
|
39 |
# Sampling training and test tasks |
|
|
40 |
num_training_tasks = len(train_file_list) |
|
|
41 |
num_meta_train = 2#num_training_tasks-1 |
|
|
42 |
num_meta_test = 1#num_training_tasks-num_meta_train # as setting num_meta_test = 1 |
|
|
43 |
|
|
|
44 |
# Randomly choosing meta train and meta test domains |
|
|
45 |
task_list = np.random.permutation(num_training_tasks) |
|
|
46 |
meta_train_index_list = task_list[:2] |
|
|
47 |
meta_test_index_list = task_list[-1:] |
|
|
48 |
|
|
|
49 |
# Sampling meta-train, meta-test data |
|
|
50 |
for i in range(num_meta_train): |
|
|
51 |
task_ind = meta_train_index_list[i] |
|
|
52 |
if i == 0: |
|
|
53 |
inputa, labela = sess.run(train_next_list[task_ind]) |
|
|
54 |
elif i == 1: |
|
|
55 |
inputa1, labela1 = sess.run(train_next_list[task_ind]) |
|
|
56 |
else: |
|
|
57 |
raise RuntimeError('check number of meta-train domains.') |
|
|
58 |
|
|
|
59 |
for i in range(num_meta_test): |
|
|
60 |
task_ind = meta_test_index_list[i] |
|
|
61 |
if i == 0: |
|
|
62 |
inputb, labelb = sess.run(train_next_list[task_ind]) |
|
|
63 |
else: |
|
|
64 |
raise RuntimeError('check number of meta-test domains.') |
|
|
65 |
|
|
|
66 |
input_group = np.concatenate((inputa[:2],inputa1[:1],inputb[:2]), axis=0) |
|
|
67 |
label_group = np.concatenate((labela[:2],labela1[:1],labelb[:2]), axis=0) |
|
|
68 |
|
|
|
69 |
contour_group, metric_label_group = _get_coutour_sample(label_group) |
|
|
70 |
|
|
|
71 |
feed_dict = {model.inputa: inputa, model.labela: labela, \ |
|
|
72 |
model.inputa1: inputa1, model.labela1: labela1, \ |
|
|
73 |
model.inputb: inputb, model.labelb: labelb, \ |
|
|
74 |
model.input_group:input_group, \ |
|
|
75 |
model.label_group:label_group, \ |
|
|
76 |
model.contour_group:contour_group, \ |
|
|
77 |
model.metric_label_group:metric_label_group, \ |
|
|
78 |
model.KEEP_PROB: 1.0} |
|
|
79 |
|
|
|
80 |
output_tensors = [model.task_train_op, model.meta_train_op, model.metric_train_op] |
|
|
81 |
output_tensors.extend([model.summ_op, model.seg_loss_b, model.compactness_loss_b, model.smoothness_loss_b, model.target_loss, model.source_loss]) |
|
|
82 |
_, _, _, summ_writer, seg_loss_b, compactness_loss_b, smoothness_loss_b, target_loss, source_loss = sess.run(output_tensors, feed_dict) |
|
|
83 |
# output_tensors = [model.task_train_op] |
|
|
84 |
# output_tensors.extend([model.source_loss]) |
|
|
85 |
# _, source_loss = sess.run(output_tensors, feed_dict) |
|
|
86 |
|
|
|
87 |
if itr % args.print_interval == 0: |
|
|
88 |
logging.info("Epoch: [%2d] [%6d/%6d] time: %4.4f inner lr:%.8f outer lr:%.8f" % (epoch, itr, args.train_iterations, (time.time()-start), model.inner_lr.eval(), model.outer_lr.eval())) |
|
|
89 |
logging.info('sou_loss: %.7f, tar_loss: %.7f, tar_seg_loss: %.7f, tar_compactness_loss: %.7f, tar_smoothness_loss: %.7f' % (source_loss, target_loss, seg_loss_b, compactness_loss_b, smoothness_loss_b)) |
|
|
90 |
|
|
|
91 |
if itr % args.summary_interval == 0: |
|
|
92 |
train_writer.add_summary(summ_writer, itr) |
|
|
93 |
train_writer.flush() |
|
|
94 |
|
|
|
95 |
if (itr!=0) and itr % args.save_freq == 0: |
|
|
96 |
saver.save(sess, args.checkpoint_dir + '/epoch_' + str(epoch) + '_itr_'+str(itr) + ".model.cpkt") |
|
|
97 |
|
|
|
98 |
# Testing periodically |
|
|
99 |
if (itr!=0) and itr % args.test_freq == 0: |
|
|
100 |
test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file, model, args) |
|
|
101 |
|
|
|
102 |
if test_dice > best_test_dice: |
|
|
103 |
best_test_dice = test_dice |
|
|
104 |
|
|
|
105 |
with open((os.path.join(args.log_dir,'eva.txt')), 'a') as f: |
|
|
106 |
print >> f, 'Iteration %d :' % (itr) |
|
|
107 |
print >> f, ' Unseen domain testing results: Dice: %f' %(test_dice), test_dice_arr |
|
|
108 |
print >> f, ' Current best accuracy %f' %(best_test_dice) |
|
|
109 |
print >> f, ' Unseen domain testing results: Haus: %f' %(test_haus), test_haus_arr |
|
|
110 |
print >> f, ' Current best accuracy %f' %(best_test_haus) |
|
|
111 |
# Save model |
|
|
112 |
|
|
|
113 |
def test(sess, test_list, model, args): |
|
|
114 |
|
|
|
115 |
dice = [] |
|
|
116 |
haus = [] |
|
|
117 |
start = time.time() |
|
|
118 |
|
|
|
119 |
with open(test_list, 'r') as fp: |
|
|
120 |
rows = fp.readlines() |
|
|
121 |
test_list = [row[:-1] if row[-1] == '\n' else row for row in rows] |
|
|
122 |
|
|
|
123 |
for fid, filename in enumerate(test_list): |
|
|
124 |
image, mask, spacing = parse_fn_haus(filename) |
|
|
125 |
pred_y = np.zeros(mask.shape) |
|
|
126 |
|
|
|
127 |
frame_list = [kk for kk in range(1, image.shape[2] - 1)] |
|
|
128 |
|
|
|
129 |
for ii in xrange(int(np.floor(image.shape[2] // model.test_batch_size))): |
|
|
130 |
vol = np.zeros([model.test_batch_size, model.volume_size[0], model.volume_size[1], model.volume_size[2]]) |
|
|
131 |
|
|
|
132 |
for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]): |
|
|
133 |
vol[idx, ...] = image[..., jj - 1: jj + 2].copy() |
|
|
134 |
|
|
|
135 |
pred_student = sess.run((model.outputs), feed_dict={model.test_input: vol, \ |
|
|
136 |
model.KEEP_PROB: 1.0,\ |
|
|
137 |
model.training_mode: True}) |
|
|
138 |
|
|
|
139 |
for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]): |
|
|
140 |
pred_y[..., jj] = pred_student[idx, ...].copy() |
|
|
141 |
|
|
|
142 |
processed_pred_y = _connectivity_region_analysis(pred_y) |
|
|
143 |
|
|
|
144 |
dice_subject = _eval_dice(mask, processed_pred_y) |
|
|
145 |
|
|
|
146 |
# print spacing |
|
|
147 |
dice.append(dice_subject) |
|
|
148 |
# haus.append(haus_subject) |
|
|
149 |
# _save_nii_prediction(mask, processed_pred_y, pred_y, args.result_dir, '_' + filename[-26:-20]) |
|
|
150 |
dice_avg = np.mean(dice, axis=0).tolist()[0] |
|
|
151 |
# haus_avg = np.mean(haus, axis=0).tolist()[0] |
|
|
152 |
|
|
|
153 |
logging.info("dice_avg %.4f" % (dice_avg)) |
|
|
154 |
# logging.info("haus_avg %.4f" % (haus_avg)) |
|
|
155 |
|
|
|
156 |
return dice_avg, dice, 0, 0 |
|
|
157 |
# return dice_avg, dice, haus_avg, haus |
|
|
158 |
|
|
|
159 |
def _save_nii_prediction(gth, comp_pred, pre_pred, out_folder, out_bname): |
|
|
160 |
sitk.WriteImage(sitk.GetImageFromArray(gth), out_folder + out_bname + 'gth.nii.gz') |
|
|
161 |
sitk.WriteImage(sitk.GetImageFromArray(pre_pred), out_folder + out_bname + 'premask.nii.gz') |
|
|
162 |
sitk.WriteImage(sitk.GetImageFromArray(comp_pred), out_folder + out_bname + 'mask.nii.gz') |