--- /dev/null
+++ b/train.py
@@ -0,0 +1,162 @@
+import numpy as np
+import tensorflow as tf
+from data_generator import ImageDataGenerator
+import logging
+from utils import _eval_dice, _connectivity_region_analysis, parse_fn, _crop_object_region, _get_coutour_sample, parse_fn_haus, _eval_haus
+import time
+import os
+import SimpleITK as sitk
+
+def train(model, saver, sess, train_file_list, test_file, args, resume_itr=0):
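+    """Meta-train across the source domains and periodically test on the unseen domain.
+
+    Each iteration forms an episode: two randomly chosen source domains serve
+    as meta-train and the remaining one as meta-test, and the task, meta, and
+    metric update ops are run on the sampled batches.
+    """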
+
+    if args.log:
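+        # TensorBoard writer; summaries are flushed every summary_interval iterations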
+        train_writer = tf.summary.FileWriter(args.log_dir + '/' + args.phase + '/', sess.graph)
+
+    # Data loaders: one tf.data pipeline per training (source) domain
+    with tf.device('/cpu:0'):
+        tr_data_list, train_iterator_list, train_next_list = [], [], []
+        for i in range(len(train_file_list)):
+            tr_data = ImageDataGenerator(train_file_list[i], mode='training',
+                                         batch_size=args.meta_batch_size, num_classes=args.n_class, shuffle=True)
+            tr_data_list.append(tr_data)
+            train_iterator_list.append(tf.data.Iterator.from_structure(tr_data.data.output_types, tr_data.data.output_shapes))
+            train_next_list.append(train_iterator_list[i].get_next())
+
+    # Ops for initializing the per-domain iterators
+    training_init_op = []
+    for i in range(len(train_file_list)):
+        training_init_op.append(train_iterator_list[i].make_initializer(tr_data_list[i].data))
+        sess.run(training_init_op[i])  # initialize each training sample generator at itr=0
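+    # The data generator presumably repeats indefinitely, so each iterator
+    # only needs this one-time initialization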
+
+    # Training begins
+    best_test_dice = 0
+    best_test_haus = 0
+    for epoch in range(args.epoch):
+        for itr in range(resume_itr, args.train_iterations):
+            start = time.time()
+
+            # Sampling training and test tasks
+            num_training_tasks = len(train_file_list)
+            num_meta_train = 2   # i.e. num_training_tasks - 1
+            num_meta_test = 1    # i.e. num_training_tasks - num_meta_train
+
+            # Randomly choose the meta-train and meta-test domains
+            task_list = np.random.permutation(num_training_tasks)
+            meta_train_index_list = task_list[:num_meta_train]
+            meta_test_index_list = task_list[-num_meta_test:]
+
+            # Sample meta-train and meta-test batches
+            for i in range(num_meta_train):
+                task_ind = meta_train_index_list[i]
+                if i == 0:
+                    inputa, labela = sess.run(train_next_list[task_ind])
+                elif i == 1:
+                    inputa1, labela1 = sess.run(train_next_list[task_ind])
+                else:
+                    raise RuntimeError('check number of meta-train domains.')
+
+            for i in range(num_meta_test):
+                task_ind = meta_test_index_list[i]
+                if i == 0:
+                    inputb, labelb = sess.run(train_next_list[task_ind])
+                else:
+                    raise RuntimeError('check number of meta-test domains.')
+
+            # Group a few samples from every domain (2 + 1 + 2) for the contour/metric branch
+            input_group = np.concatenate((inputa[:2], inputa1[:1], inputb[:2]), axis=0)
+            label_group = np.concatenate((labela[:2], labela1[:1], labelb[:2]), axis=0)
+
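+            # _get_coutour_sample presumably derives contour/background masks and
+            # pairing labels from the ground-truth masks for the metric (embedding) loss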
+            contour_group, metric_label_group = _get_coutour_sample(label_group)
+
+            feed_dict = {model.inputa: inputa, model.labela: labela,
+                         model.inputa1: inputa1, model.labela1: labela1,
+                         model.inputb: inputb, model.labelb: labelb,
+                         model.input_group: input_group,
+                         model.label_group: label_group,
+                         model.contour_group: contour_group,
+                         model.metric_label_group: metric_label_group,
+                         model.KEEP_PROB: 1.0}
+
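+            # Jointly run the three updates; by their names, task_train_op is the
+            # inner (source) update, meta_train_op the outer update on the virtual
+            # test domain, and metric_train_op the contour-embedding update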
+            output_tensors = [model.task_train_op, model.meta_train_op, model.metric_train_op]
+            output_tensors.extend([model.summ_op, model.seg_loss_b, model.compactness_loss_b, model.smoothness_loss_b, model.target_loss, model.source_loss])
+            _, _, _, summary_str, seg_loss_b, compactness_loss_b, smoothness_loss_b, target_loss, source_loss = sess.run(output_tensors, feed_dict)
+
+            if itr % args.print_interval == 0:
+                logging.info("Epoch: [%2d] [%6d/%6d] time: %4.4f inner lr: %.8f outer lr: %.8f"
+                             % (epoch, itr, args.train_iterations, time.time() - start, model.inner_lr.eval(), model.outer_lr.eval()))
+                logging.info('sou_loss: %.7f, tar_loss: %.7f, tar_seg_loss: %.7f, tar_compactness_loss: %.7f, tar_smoothness_loss: %.7f'
+                             % (source_loss, target_loss, seg_loss_b, compactness_loss_b, smoothness_loss_b))
+
+            if args.log and itr % args.summary_interval == 0:
+                train_writer.add_summary(summary_str, itr)
+                train_writer.flush()
+
+            if itr != 0 and itr % args.save_freq == 0:
+                saver.save(sess, args.checkpoint_dir + '/epoch_' + str(epoch) + '_itr_' + str(itr) + ".model.cpkt")
+
+            # Test periodically on the unseen domain
+            if itr != 0 and itr % args.test_freq == 0:
+                test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file, model, args)
+
+                if test_dice > best_test_dice:
+                    best_test_dice = test_dice
+                # best_test_haus stays at its initial value until the Hausdorff
+                # evaluation in test() is re-enabled
+
+                with open(os.path.join(args.log_dir, 'eva.txt'), 'a') as f:
+                    print('Iteration %d :' % itr, file=f)
+                    print('\tUnseen domain testing results: Dice: %f' % test_dice, test_dice_arr, file=f)
+                    print('\tCurrent best Dice %f' % best_test_dice, file=f)
+                    print('\tUnseen domain testing results: Haus: %f' % test_haus, test_haus_arr, file=f)
+                    print('\tCurrent best Haus %f' % best_test_haus, file=f)
+
+def test(sess, test_list, model, args):
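+    """Evaluate the model on the unseen-domain volumes listed in test_list.
+
+    test_list is a text file with one volume path per line. Returns the mean
+    Dice, the per-subject Dice scores, and placeholder zeros for the currently
+    disabled Hausdorff metrics.
+    """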
+
+    dice = []
+    haus = []
+
+    with open(test_list, 'r') as fp:
+        rows = fp.readlines()
+    test_list = [row.rstrip('\n') for row in rows]
+
+    for fid, filename in enumerate(test_list):
+        image, mask, spacing = parse_fn_haus(filename)
+        pred_y = np.zeros(mask.shape)
+
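+        # Each test input stacks 3 adjacent slices as channels (image[..., jj-1:jj+2]),
+        # so the first and last slice of the volume are skipped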
+        frame_list = list(range(1, image.shape[2] - 1))
+
+        # ceil: include the final, possibly partial, batch of slices
+        for ii in range(int(np.ceil(len(frame_list) / model.test_batch_size))):
+            vol = np.zeros([model.test_batch_size, model.volume_size[0], model.volume_size[1], model.volume_size[2]])
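+            # Rows of vol beyond the last available frame stay zero; their
+            # predictions are never copied back into pred_y below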
+
+            for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]):
+                vol[idx, ...] = image[..., jj - 1: jj + 2].copy()
+
+            pred_student = sess.run(model.outputs,
+                                    feed_dict={model.test_input: vol,
+                                               model.KEEP_PROB: 1.0,
+                                               model.training_mode: True})
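+            # Note: inference feeds training_mode=True; presumably this keeps the
+            # normalization behavior consistent with how the network was trained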
+
+            for idx, jj in enumerate(frame_list[ii * model.test_batch_size: (ii + 1) * model.test_batch_size]):
+                pred_y[..., jj] = pred_student[idx, ...].copy()
+
+        # Connected-component post-processing of the hard predictions
+        processed_pred_y = _connectivity_region_analysis(pred_y)
+
+        dice_subject = _eval_dice(mask, processed_pred_y)
+        dice.append(dice_subject)
+        # haus.append(haus_subject)  # Hausdorff evaluation currently disabled
+        # _save_nii_prediction(mask, processed_pred_y, pred_y, args.result_dir, '_' + filename[-26:-20])
+
+    dice_avg = np.mean(dice, axis=0).tolist()[0]
+    # haus_avg = np.mean(haus, axis=0).tolist()[0]
+
+    logging.info("dice_avg %.4f" % dice_avg)
+    # logging.info("haus_avg %.4f" % haus_avg)
+
+    # Hausdorff metrics are returned as placeholders until _eval_haus is re-enabled
+    return dice_avg, dice, 0, 0
+
+def _save_nii_prediction(gth, comp_pred, pre_pred, out_folder, out_bname):
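+    """Save the ground truth, raw prediction, and connectivity-processed prediction as .nii.gz volumes."""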
+    sitk.WriteImage(sitk.GetImageFromArray(gth), os.path.join(out_folder, out_bname + 'gth.nii.gz'))
+    sitk.WriteImage(sitk.GetImageFromArray(pre_pred), os.path.join(out_folder, out_bname + 'premask.nii.gz'))
+    sitk.WriteImage(sitk.GetImageFromArray(comp_pred), os.path.join(out_folder, out_bname + 'mask.nii.gz'))