Diff of /main.py [000000] .. [7b5b9f]

Switch to unified view

a b/main.py
1
import sys
2
import os
3
import numpy as np
4
import tensorflow as tf
5
from tensorflow.python.platform import flags
6
from data_generator import ImageDataGenerator
7
from saml_func import SAML
8
from train import train
9
from train import test
10
import datetime
11
import argparse
12
from utils import check_folder, show_all_variables
13
import logging
14
15
currtime = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
16
tf.set_random_seed(2)
17
18
def parse_args(train_date):
19
    desc = "Tensorflow implementation of DenseUNet for prostate segmentation"
20
    parser = argparse.ArgumentParser(description=desc)
21
    parser.add_argument('--gpu', type=str, default='0', help='train or test or guide')
22
    parser.add_argument('--phase', type=str, default='train', help='train or test or guide')
23
    parser.add_argument('--n_class', type=int, default=2, help='The size of class')
24
25
    ## Training operations
26
    parser.add_argument('--target_domain', type=str, default='ISBI', help='dataset_name')
27
    parser.add_argument('--volume_size', type=list, default=[384, 384, 3], help='The size of input data')
28
    parser.add_argument('--label_size', type=list, default=[384, 384, 1], help='The size of label')
29
    parser.add_argument('--epoch', type=int, default=1, help='The number of epochs to run')
30
    parser.add_argument('--train_iterations', type=int, default=10000, help='The number of training iterations')
31
    parser.add_argument('--meta_batch_size', type=int, default=5, help='number of images sampled per source domain')
32
    parser.add_argument('--test_batch_size', type=int, default=1, help='number of images sampled per source domain')
33
    parser.add_argument('--inner_lr', type=float, default=1e-4, help='The learning rate')
34
    parser.add_argument('--outer_lr', type=float, default=1e-3, help='The learning rate')
35
    parser.add_argument('--metric_lr', type=float, default=1e-3, help='The learning rate')
36
    parser.add_argument('--margin', type=float, default=10.0, help='The learning rate')
37
    parser.add_argument('--compactness_loss_weight', type=float, default=1.0, help='The learning rate')
38
    parser.add_argument('--smoothness_loss_weight', type=float, default=0.005, help='The learning rate')
39
    parser.add_argument('--clipNorm', type=int, default=True, help='number of images sampled per source domain')
40
    parser.add_argument('--gradients_clip_value', type=float, default=10.0, help='The learning rate')
41
42
    # Logging, saving, and testing options
43
    parser.add_argument('--resume', type=int, default=False, help='number of images sampled per source domain')
44
    parser.add_argument('--log', type=int, default=True, help='write tensorboard')
45
    parser.add_argument('--decay_step', type=float, default=500, help='The learning rate')
46
    parser.add_argument('--decay_rate', type=float, default=0.95, help='The learning rate')
47
    parser.add_argument('--test_freq', type=int, default=200, help='The number of ckpt_save_freq')
48
    parser.add_argument('--save_freq', type=int, default=200, help='The number of ckpt_save_freq')
49
    parser.add_argument('--print_interval', type=int, default=5, help='The frequency to write tensorboard')
50
    parser.add_argument('--summary_interval', type=int, default=20, help='The frequency to write tensorboard')
51
    parser.add_argument('--restored_model', type=str, default=None, help='Model to restore')
52
    parser.add_argument('--test_model', type=str, default=None, help='Model to restore')
53
    # parser.add_argument('--dropout', type=str, default=1, help='dropout rate')
54
    # parser.add_argument('--cost_kwargs', type=str, default=1, help='cost_kwargs')
55
    # parser.add_argument('--opt_kwargs', type=str, default=1, help='opt_kwargs')
56
57
    parser.add_argument('--checkpoint_dir', type=str, default='../output/' + train_date + '/checkpoints/' ,
58
                        help='Directory name to save the checkpoints')
59
    parser.add_argument('--result_dir', type=str, default='../output/' + train_date + '/results/',
60
                        help='Directory name to save the generated images')
61
    parser.add_argument('--log_dir', type=str, default='../output/' + train_date + '/logs/',
62
                        help='Directory name to save training logs')
63
    parser.add_argument('--sample_dir', type=str, default='../output/' + train_date + '/samples/',
64
                        help='Directory name to save the samples on training')
65
66
    return check_args(parser.parse_args())
67
68
"""checking arguments"""
69
def check_args(args):
70
    # --checkpoint_dir
71
    check_folder(args.checkpoint_dir)
72
    # --result_dir
73
    check_folder(args.result_dir)
74
    # --result_dir
75
    check_folder(args.log_dir)
76
    # --sample_dir
77
    check_folder(args.sample_dir)
78
79
    return args
80
81
def main():
82
    train_date = 'xxx'
83
    args = parse_args(train_date)
84
85
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
86
87
    # define logger
88
    logging.basicConfig(filename=args.log_dir+"/"+args.phase+'_log.txt', level=logging.DEBUG, format='%(asctime)s %(message)s')
89
    logging.getLogger().addHandler(logging.StreamHandler())
90
91
    # print all parameters
92
    logging.info("Usage:")
93
    logging.info("    {0}".format(" ".join([x for x in sys.argv]))) 
94
    logging.debug("All settings used:")
95
96
    os.system('cp main.py %s' % (args.log_dir)) # bkp of train procedure
97
    os.system('cp saml_func.py %s' % (args.log_dir)) # bkp of train procedure
98
    os.system('cp train.py %s' % (args.log_dir)) # bkp of train procedure
99
    os.system('cp utils.py %s' % (args.log_dir)) # bkp of train procedure
100
    os.system('cp data_generator.py %s' % (args.log_dir))
101
102
103
    filelist_root = '../dataset'
104
    source_list = ['HK', 'ISBI', 'ISBI_1.5', 'I2CVB','UCL', 'BIDMC']#'ISBI_1.5', 'I2CVB', 'UCL','BIDMC']#, 'I2CVB', 'ISBI_1.5', 'UCL', 'BIDMC']#'I2CVB', 'UCL', 'BIDMC', 'HK']
105
    source_list.remove(args.target_domain)
106
107
    # Constructing model
108
    model = SAML(args)
109
    model.construct_model_train()
110
    model.construct_model_test()
111
    
112
    model.summ_op = tf.summary.merge_all()
113
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
114
    sess = tf.InteractiveSession()
115
116
    tf.global_variables_initializer().run()
117
    show_all_variables()
118
119
    # restore model ----
120
    resume_itr = 0
121
    model_file = None
122
    if args.resume:
123
        model_file = tf.train.latest_checkpoint(args.checkpoint_dir)
124
        if model_file:
125
            ind1 = model_file.index('model')
126
            resume_itr = int(model_file[ind1+5:])
127
            print("Restoring model weights from " + model_file)
128
            saver.restore(sess, model_file)
129
130
    train_file_list = [os.path.join(filelist_root, source_domain+'_train_list') for source_domain in source_list]
131
    test_file_list = [os.path.join(filelist_root, args.target_domain+'_train_list')]
132
133
    # start training ----
134
    if args.phase == 'train':
135
        train(model, saver, sess, train_file_list, test_file_list[0], args, resume_itr)
136
    else:
137
        args.test_model = 'xxx'
138
        saver.restore(sess, args.test_model)
139
        logging.info("testing model restored %s" % args.test_model)
140
141
        test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file_list[0], model, args)
142
        with open((os.path.join(args.log_dir,'test.txt')), 'a') as f:
143
            print >> f, 'testing model %s :' % (args.test_model)
144
            print >> f, '   Unseen domain testing results: Dice: %f' %(test_dice), test_dice_arr
145
            print >> f, '   Unseen domain testing results: Haus: %f' %(test_haus), test_haus_arr
146
147
if __name__ == "__main__":
148
    main()