Diff of /train.py [000000] .. [7b5b9f]

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from data_generator import ImageDataGenerator
import logging
from utils import _eval_dice, _connectivity_region_analysis, parse_fn, _crop_object_region, \
    _get_coutour_sample, parse_fn_haus, _eval_haus
import time
import os
import SimpleITK as sitk


def train(model, saver, sess, train_file_list, test_file, args, resume_itr=0):

    if args.log:
        train_writer = tf.summary.FileWriter(args.log_dir + '/' + args.phase + '/', sess.graph)

    # Data loaders
    with tf.device('/cpu:0'):
        tr_data_list, train_iterator_list, train_next_list = [], [], []
        for i in range(len(train_file_list)):
            tr_data = ImageDataGenerator(train_file_list[i], mode='training',
                                         batch_size=args.meta_batch_size, num_classes=args.n_class, shuffle=True)
            tr_data_list.append(tr_data)
            train_iterator_list.append(tf.data.Iterator.from_structure(tr_data.data.output_types, tr_data.data.output_shapes))
            train_next_list.append(train_iterator_list[i].get_next())

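    # One ImageDataGenerator and one reinitializable (from_structure) iterator are
    # kept per source domain, so each meta-iteration below can draw a batch from any
    # sampled domain independently via sess.run(train_next_list[task_ind]).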
    # Ops for initializing different iterators
    training_init_op = []
    train_batches_per_epoch = []
    for i in range(len(train_file_list)):
        training_init_op.append(train_iterator_list[i].make_initializer(tr_data_list[i].data))
        sess.run(training_init_op[i])  # initialize training sample generator at itr=0

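    # The iterators are initialized only once, at itr = 0; this presumes the
    # ImageDataGenerator pipeline repeats (and reshuffles) indefinitely so that
    # batches can be drawn for the whole training run without re-initialization.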
    # Training begins
    best_test_dice = 0
    best_test_haus = 0
    for epoch in range(args.epoch):
        for itr in range(resume_itr, args.train_iterations):
            start = time.time()

            # Sampling training and test tasks
            num_training_tasks = len(train_file_list)
            num_meta_train = 2  # num_training_tasks - 1
            num_meta_test = 1   # num_training_tasks - num_meta_train

            # Randomly choose meta-train and meta-test domains
            task_list = np.random.permutation(num_training_tasks)
            meta_train_index_list = task_list[:num_meta_train]
            meta_test_index_list = task_list[-num_meta_test:]

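            # Episodic set-up: at each iteration the source domains are randomly split
            # into two meta-train domains and one held-out meta-test domain, so every
            # update is computed under a simulated domain shift.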
            # Sampling meta-train, meta-test data
            for i in range(num_meta_train):
                task_ind = meta_train_index_list[i]
                if i == 0:
                    inputa, labela = sess.run(train_next_list[task_ind])
                elif i == 1:
                    inputa1, labela1 = sess.run(train_next_list[task_ind])
                else:
                    raise RuntimeError('check number of meta-train domains.')

            for i in range(num_meta_test):
                task_ind = meta_test_index_list[i]
                if i == 0:
                    inputb, labelb = sess.run(train_next_list[task_ind])
                else:
                    raise RuntimeError('check number of meta-test domains.')

            input_group = np.concatenate((inputa[:2], inputa1[:1], inputb[:2]), axis=0)
            label_group = np.concatenate((labela[:2], labela1[:1], labelb[:2]), axis=0)

            contour_group, metric_label_group = _get_coutour_sample(label_group)

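            # input_group mixes samples across the sampled domains (two from the first
            # meta-train domain, one from the second, two from the meta-test domain).
            # _get_coutour_sample presumably derives contour/background samples and the
            # matching metric labels from label_group for the metric (embedding)
            # objective fed through contour_group / metric_label_group below.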
            feed_dict = {model.inputa: inputa, model.labela: labela,
                         model.inputa1: inputa1, model.labela1: labela1,
                         model.inputb: inputb, model.labelb: labelb,
                         model.input_group: input_group,
                         model.label_group: label_group,
                         model.contour_group: contour_group,
                         model.metric_label_group: metric_label_group,
                         model.KEEP_PROB: 1.0}

            output_tensors = [model.task_train_op, model.meta_train_op, model.metric_train_op]
            output_tensors.extend([model.summ_op, model.seg_loss_b, model.compactness_loss_b, model.smoothness_loss_b, model.target_loss, model.source_loss])
            _, _, _, summ_writer, seg_loss_b, compactness_loss_b, smoothness_loss_b, target_loss, source_loss = sess.run(output_tensors, feed_dict)
            # output_tensors = [model.task_train_op]
            # output_tensors.extend([model.source_loss])
            # _, source_loss = sess.run(output_tensors, feed_dict)

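            # The three train ops appear to correspond to the three objectives:
            # task_train_op for the inner update on the meta-train domains,
            # meta_train_op for the outer (meta) update driven by the meta-test
            # losses, and metric_train_op for the contour-embedding metric branch;
            # the fetched *_b losses are the meta-test-domain terms logged below.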
            if itr % args.print_interval == 0:
                logging.info("Epoch: [%2d] [%6d/%6d] time: %4.4f inner lr: %.8f outer lr: %.8f"
                             % (epoch, itr, args.train_iterations, time.time() - start, model.inner_lr.eval(), model.outer_lr.eval()))
                logging.info('sou_loss: %.7f, tar_loss: %.7f, tar_seg_loss: %.7f, tar_compactness_loss: %.7f, tar_smoothness_loss: %.7f'
                             % (source_loss, target_loss, seg_loss_b, compactness_loss_b, smoothness_loss_b))

            if args.log and itr % args.summary_interval == 0:  # train_writer exists only when args.log is set
                train_writer.add_summary(summ_writer, itr)
                train_writer.flush()

            if itr != 0 and itr % args.save_freq == 0:
                saver.save(sess, args.checkpoint_dir + '/epoch_' + str(epoch) + '_itr_' + str(itr) + ".model.cpkt")

            # Testing periodically
            if itr != 0 and itr % args.test_freq == 0:
                test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file, model, args)

                if test_dice > best_test_dice:
                    best_test_dice = test_dice

                with open(os.path.join(args.log_dir, 'eva.txt'), 'a') as f:
                    f.write('Iteration %d :\n' % itr)
                    f.write('   Unseen domain testing results: Dice: %f %s\n' % (test_dice, test_dice_arr))
                    f.write('   Current best Dice: %f\n' % best_test_dice)
                    f.write('   Unseen domain testing results: Haus: %f %s\n' % (test_haus, test_haus_arr))
                    f.write('   Current best Haus: %f\n' % best_test_haus)
                # Save model

def test(sess, test_list, model, args):

    dice = []
    haus = []
    start = time.time()

    with open(test_list, 'r') as fp:
        rows = fp.readlines()
    test_list = [row[:-1] if row[-1] == '\n' else row for row in rows]

    for fid, filename in enumerate(test_list):
        image, mask, spacing = parse_fn_haus(filename)
        pred_y = np.zeros(mask.shape)

        frame_list = [kk for kk in range(1, image.shape[2] - 1)]

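        # frame_list skips the first and last slices so that every test sample can be
        # built from three adjacent slices (jj-1, jj, jj+1) stacked as input channels;
        # slices that never receive a prediction keep their zero initialization in pred_y.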
        for ii in range(int(np.floor(image.shape[2] // model.test_batch_size))):
            vol = np.zeros([model.test_batch_size, model.volume_size[0], model.volume_size[1], model.volume_size[2]])

            for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]):
                vol[idx, ...] = image[..., jj - 1: jj + 2].copy()

            pred_student = sess.run(model.outputs, feed_dict={model.test_input: vol,
                                                              model.KEEP_PROB: 1.0,
                                                              model.training_mode: True})

            for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]):
                pred_y[..., jj] = pred_student[idx, ...].copy()

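        # pred_y is assembled batch by batch; the first and last slices (and any trailing
        # slices the batching does not reach) remain zero before the connected-component
        # post-processing below.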
        processed_pred_y = _connectivity_region_analysis(pred_y)

        dice_subject = _eval_dice(mask, processed_pred_y)

        # print spacing
        dice.append(dice_subject)
        # haus.append(haus_subject)
        # _save_nii_prediction(mask, processed_pred_y, pred_y, args.result_dir, '_' + filename[-26:-20])
    dice_avg = np.mean(dice, axis=0).tolist()[0]
    # haus_avg = np.mean(haus, axis=0).tolist()[0]

    logging.info("dice_avg %.4f" % dice_avg)
    # logging.info("haus_avg %.4f" % haus_avg)

    return dice_avg, dice, 0, 0  # Hausdorff evaluation disabled; zeros returned as placeholders
    # return dice_avg, dice, haus_avg, haus


def _save_nii_prediction(gth, comp_pred, pre_pred, out_folder, out_bname):
    sitk.WriteImage(sitk.GetImageFromArray(gth), out_folder + out_bname + 'gth.nii.gz')
    sitk.WriteImage(sitk.GetImageFromArray(pre_pred), out_folder + out_bname + 'premask.nii.gz')
    sitk.WriteImage(sitk.GetImageFromArray(comp_pred), out_folder + out_bname + 'mask.nii.gz')
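

# Hypothetical usage of _save_nii_prediction, mirroring the commented-out call in
# test(); the basename '_Case00' stands in for the '_' + filename[-26:-20] slice used
# there, and out_folder must already end with a path separator because the output
# filenames are built by plain string concatenation:
#
#   _save_nii_prediction(mask, processed_pred_y, pred_y,
#                        out_folder=args.result_dir, out_bname='_Case00')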