|
a |
|
b/main.py |
|
|
1 |
import sys |
|
|
2 |
import os |
|
|
3 |
import numpy as np |
|
|
4 |
import tensorflow as tf |
|
|
5 |
from tensorflow.python.platform import flags |
|
|
6 |
from data_generator import ImageDataGenerator |
|
|
7 |
from saml_func import SAML |
|
|
8 |
from train import train |
|
|
9 |
from train import test |
|
|
10 |
import datetime |
|
|
11 |
import argparse |
|
|
12 |
from utils import check_folder, show_all_variables |
|
|
13 |
import logging |
|
|
14 |
|
|
|
15 |
currtime = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') |
|
|
16 |
tf.set_random_seed(2) |
|
|
17 |
|
|
|
18 |
def parse_args(train_date): |
|
|
19 |
desc = "Tensorflow implementation of DenseUNet for prostate segmentation" |
|
|
20 |
parser = argparse.ArgumentParser(description=desc) |
|
|
21 |
parser.add_argument('--gpu', type=str, default='0', help='train or test or guide') |
|
|
22 |
parser.add_argument('--phase', type=str, default='train', help='train or test or guide') |
|
|
23 |
parser.add_argument('--n_class', type=int, default=2, help='The size of class') |
|
|
24 |
|
|
|
25 |
## Training operations |
|
|
26 |
parser.add_argument('--target_domain', type=str, default='ISBI', help='dataset_name') |
|
|
27 |
parser.add_argument('--volume_size', type=list, default=[384, 384, 3], help='The size of input data') |
|
|
28 |
parser.add_argument('--label_size', type=list, default=[384, 384, 1], help='The size of label') |
|
|
29 |
parser.add_argument('--epoch', type=int, default=1, help='The number of epochs to run') |
|
|
30 |
parser.add_argument('--train_iterations', type=int, default=10000, help='The number of training iterations') |
|
|
31 |
parser.add_argument('--meta_batch_size', type=int, default=5, help='number of images sampled per source domain') |
|
|
32 |
parser.add_argument('--test_batch_size', type=int, default=1, help='number of images sampled per source domain') |
|
|
33 |
parser.add_argument('--inner_lr', type=float, default=1e-4, help='The learning rate') |
|
|
34 |
parser.add_argument('--outer_lr', type=float, default=1e-3, help='The learning rate') |
|
|
35 |
parser.add_argument('--metric_lr', type=float, default=1e-3, help='The learning rate') |
|
|
36 |
parser.add_argument('--margin', type=float, default=10.0, help='The learning rate') |
|
|
37 |
parser.add_argument('--compactness_loss_weight', type=float, default=1.0, help='The learning rate') |
|
|
38 |
parser.add_argument('--smoothness_loss_weight', type=float, default=0.005, help='The learning rate') |
|
|
39 |
parser.add_argument('--clipNorm', type=int, default=True, help='number of images sampled per source domain') |
|
|
40 |
parser.add_argument('--gradients_clip_value', type=float, default=10.0, help='The learning rate') |
|
|
41 |
|
|
|
42 |
# Logging, saving, and testing options |
|
|
43 |
parser.add_argument('--resume', type=int, default=False, help='number of images sampled per source domain') |
|
|
44 |
parser.add_argument('--log', type=int, default=True, help='write tensorboard') |
|
|
45 |
parser.add_argument('--decay_step', type=float, default=500, help='The learning rate') |
|
|
46 |
parser.add_argument('--decay_rate', type=float, default=0.95, help='The learning rate') |
|
|
47 |
parser.add_argument('--test_freq', type=int, default=200, help='The number of ckpt_save_freq') |
|
|
48 |
parser.add_argument('--save_freq', type=int, default=200, help='The number of ckpt_save_freq') |
|
|
49 |
parser.add_argument('--print_interval', type=int, default=5, help='The frequency to write tensorboard') |
|
|
50 |
parser.add_argument('--summary_interval', type=int, default=20, help='The frequency to write tensorboard') |
|
|
51 |
parser.add_argument('--restored_model', type=str, default=None, help='Model to restore') |
|
|
52 |
parser.add_argument('--test_model', type=str, default=None, help='Model to restore') |
|
|
53 |
# parser.add_argument('--dropout', type=str, default=1, help='dropout rate') |
|
|
54 |
# parser.add_argument('--cost_kwargs', type=str, default=1, help='cost_kwargs') |
|
|
55 |
# parser.add_argument('--opt_kwargs', type=str, default=1, help='opt_kwargs') |
|
|
56 |
|
|
|
57 |
parser.add_argument('--checkpoint_dir', type=str, default='../output/' + train_date + '/checkpoints/' , |
|
|
58 |
help='Directory name to save the checkpoints') |
|
|
59 |
parser.add_argument('--result_dir', type=str, default='../output/' + train_date + '/results/', |
|
|
60 |
help='Directory name to save the generated images') |
|
|
61 |
parser.add_argument('--log_dir', type=str, default='../output/' + train_date + '/logs/', |
|
|
62 |
help='Directory name to save training logs') |
|
|
63 |
parser.add_argument('--sample_dir', type=str, default='../output/' + train_date + '/samples/', |
|
|
64 |
help='Directory name to save the samples on training') |
|
|
65 |
|
|
|
66 |
return check_args(parser.parse_args()) |
|
|
67 |
|
|
|
68 |
"""checking arguments""" |
|
|
69 |
def check_args(args): |
|
|
70 |
# --checkpoint_dir |
|
|
71 |
check_folder(args.checkpoint_dir) |
|
|
72 |
# --result_dir |
|
|
73 |
check_folder(args.result_dir) |
|
|
74 |
# --result_dir |
|
|
75 |
check_folder(args.log_dir) |
|
|
76 |
# --sample_dir |
|
|
77 |
check_folder(args.sample_dir) |
|
|
78 |
|
|
|
79 |
return args |
|
|
80 |
|
|
|
81 |
def main(): |
|
|
82 |
train_date = 'xxx' |
|
|
83 |
args = parse_args(train_date) |
|
|
84 |
|
|
|
85 |
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu |
|
|
86 |
|
|
|
87 |
# define logger |
|
|
88 |
logging.basicConfig(filename=args.log_dir+"/"+args.phase+'_log.txt', level=logging.DEBUG, format='%(asctime)s %(message)s') |
|
|
89 |
logging.getLogger().addHandler(logging.StreamHandler()) |
|
|
90 |
|
|
|
91 |
# print all parameters |
|
|
92 |
logging.info("Usage:") |
|
|
93 |
logging.info(" {0}".format(" ".join([x for x in sys.argv]))) |
|
|
94 |
logging.debug("All settings used:") |
|
|
95 |
|
|
|
96 |
os.system('cp main.py %s' % (args.log_dir)) # bkp of train procedure |
|
|
97 |
os.system('cp saml_func.py %s' % (args.log_dir)) # bkp of train procedure |
|
|
98 |
os.system('cp train.py %s' % (args.log_dir)) # bkp of train procedure |
|
|
99 |
os.system('cp utils.py %s' % (args.log_dir)) # bkp of train procedure |
|
|
100 |
os.system('cp data_generator.py %s' % (args.log_dir)) |
|
|
101 |
|
|
|
102 |
|
|
|
103 |
filelist_root = '../dataset' |
|
|
104 |
source_list = ['HK', 'ISBI', 'ISBI_1.5', 'I2CVB','UCL', 'BIDMC']#'ISBI_1.5', 'I2CVB', 'UCL','BIDMC']#, 'I2CVB', 'ISBI_1.5', 'UCL', 'BIDMC']#'I2CVB', 'UCL', 'BIDMC', 'HK'] |
|
|
105 |
source_list.remove(args.target_domain) |
|
|
106 |
|
|
|
107 |
# Constructing model |
|
|
108 |
model = SAML(args) |
|
|
109 |
model.construct_model_train() |
|
|
110 |
model.construct_model_test() |
|
|
111 |
|
|
|
112 |
model.summ_op = tf.summary.merge_all() |
|
|
113 |
saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)) |
|
|
114 |
sess = tf.InteractiveSession() |
|
|
115 |
|
|
|
116 |
tf.global_variables_initializer().run() |
|
|
117 |
show_all_variables() |
|
|
118 |
|
|
|
119 |
# restore model ---- |
|
|
120 |
resume_itr = 0 |
|
|
121 |
model_file = None |
|
|
122 |
if args.resume: |
|
|
123 |
model_file = tf.train.latest_checkpoint(args.checkpoint_dir) |
|
|
124 |
if model_file: |
|
|
125 |
ind1 = model_file.index('model') |
|
|
126 |
resume_itr = int(model_file[ind1+5:]) |
|
|
127 |
print("Restoring model weights from " + model_file) |
|
|
128 |
saver.restore(sess, model_file) |
|
|
129 |
|
|
|
130 |
train_file_list = [os.path.join(filelist_root, source_domain+'_train_list') for source_domain in source_list] |
|
|
131 |
test_file_list = [os.path.join(filelist_root, args.target_domain+'_train_list')] |
|
|
132 |
|
|
|
133 |
# start training ---- |
|
|
134 |
if args.phase == 'train': |
|
|
135 |
train(model, saver, sess, train_file_list, test_file_list[0], args, resume_itr) |
|
|
136 |
else: |
|
|
137 |
args.test_model = 'xxx' |
|
|
138 |
saver.restore(sess, args.test_model) |
|
|
139 |
logging.info("testing model restored %s" % args.test_model) |
|
|
140 |
|
|
|
141 |
test_dice, test_dice_arr, test_haus, test_haus_arr = test(sess, test_file_list[0], model, args) |
|
|
142 |
with open((os.path.join(args.log_dir,'test.txt')), 'a') as f: |
|
|
143 |
print >> f, 'testing model %s :' % (args.test_model) |
|
|
144 |
print >> f, ' Unseen domain testing results: Dice: %f' %(test_dice), test_dice_arr |
|
|
145 |
print >> f, ' Unseen domain testing results: Haus: %f' %(test_haus), test_haus_arr |
|
|
146 |
|
|
|
147 |
if __name__ == "__main__": |
|
|
148 |
main() |