Switch to side-by-side view

--- a
+++ b/RefineNet & SESNet/SESNet/multi_gpu_train.py
@@ -0,0 +1,249 @@
+import tensorflow as tf
+from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.framework import ops as _ops
+import time
+import shutil
+import datetime
+import os
+import cv2
+import pickle
+import numpy as np
+from tensorflow.contrib import slim
+import sys
+
+sys.path.append(os.getcwd())
+from nets import model as model
+from utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors
+from utils.pascal_voc import pascal_segmentation_lut
+
# Command-line flags (tf.app.flags). Reference run: batch_size 2, gpu_list '0,1'.

# Model / optimization hyper-parameters.
tf.app.flags.DEFINE_string('model_type', 'refinenet', 'refinenet or sesnet')
tf.app.flags.DEFINE_integer('batch_size', 2, '')
tf.app.flags.DEFINE_integer('train_size', 512, '')
tf.app.flags.DEFINE_float('learning_rate', 0.00001, '')
tf.app.flags.DEFINE_integer('max_steps', 60000, '')
tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '')
tf.app.flags.DEFINE_integer('num_classes', 2, '')
# Device selection (comma-separated GPU ids, exported to CUDA_VISIBLE_DEVICES).
tf.app.flags.DEFINE_string('gpu_list', '0,1', '')
# Output locations and checkpoint / summary cadence.
tf.app.flags.DEFINE_string('checkpoint_path', 'checkpoints/', '')
tf.app.flags.DEFINE_string('logs_path', 'logs/', '')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('save_checkpoint_steps', 500, '')
tf.app.flags.DEFINE_integer('save_summary_steps', 10, '')
tf.app.flags.DEFINE_integer('save_image_steps', 10, '')
# Input data and pretrained backbone weights.
tf.app.flags.DEFINE_string('training_data_path', 'data/train.tfrecords', '')
tf.app.flags.DEFINE_string('pretrained_model_path', 'data/resnet_v1_101.ckpt', '')
# Learning-rate step decay: multiply by decay_rate every decay_steps steps.
tf.app.flags.DEFINE_integer('decay_steps', 20000, '')
tf.app.flags.DEFINE_float('decay_rate', 0.1, '')
FLAGS = tf.app.flags.FLAGS
+
+
def tower_loss(images, annotation, class_labels, reuse_variables=None):
	"""Build one GPU tower: forward pass, loss, and (first tower only) summaries.

	Returns (total_loss, model_loss, pred, logits) where total_loss adds the
	regularization losses on top of the model loss.
	"""
	current_scope = tf.get_variable_scope()
	with tf.variable_scope(current_scope, reuse=reuse_variables):
		seg_logits = model.model(FLAGS.model_type, images, is_training=True)

	# Hard class prediction along the channel axis.
	seg_pred = tf.argmax(seg_logits, dimension=3)

	loss_model = model.loss(annotation, seg_logits, class_labels)
	reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
	loss_total = tf.add_n([loss_model] + reg_losses)

	# Only the first tower (before variable reuse kicks in) writes summaries.
	if reuse_variables is None:
		tf.summary.scalar('model_loss', loss_model)
		tf.summary.scalar('total_loss', loss_total)

	return loss_total, loss_model, seg_pred, seg_logits
+
+
def average_gradients(tower_grads):
	"""Average gradients over all towers.

	tower_grads is a list (one entry per tower) of lists of
	(gradient, variable) pairs as produced by Optimizer.compute_gradients.
	Returns a single list of (averaged_gradient, variable) pairs.
	"""
	averaged = []
	# zip(*...) regroups the per-tower lists into per-variable tuples.
	for per_variable in zip(*tower_grads):
		stacked = tf.concat([tf.expand_dims(g, 0) for g, _ in per_variable], 0)
		mean_grad = tf.reduce_mean(stacked, 0)
		# Towers share variables, so the first tower's variable handle suffices.
		shared_var = per_variable[0][1]
		averaged.append((mean_grad, shared_var))
	return averaged
+
+
def build_image_summary():
	"""Create a feedable image summary op.

	Returns (summary_op, image_placeholder, name_placeholder); the caller
	feeds an HxWx3 uint8 image and a tag string, then runs summary_op.
	"""
	image_ph = tf.placeholder(tf.uint8, [None, None, 3])
	name_ph = tf.placeholder(tf.string)
	batched = tf.expand_dims(image_ph, 0)
	summary_op = gen_logging_ops.image_summary(name_ph, batched, max_images=1)
	# Register with the summary collection so FileWriter picks it up.
	_ops.add_to_collection(_ops.GraphKeys.SUMMARIES, summary_op)
	return summary_op, image_ph, name_ph
+
+
def _write_image_summary(sess, step, iis, isms, output_pred, color_map, class_labels,
                         log_image, log_image_data, log_image_name, summary_writer):
	"""Render (input | ground-truth | prediction) side by side and log the image.

	iis/isms/output_pred are the last tower's input/annotation/prediction
	tensors — TODO confirm that logging only the last tower is intended.
	"""
	img_split, seg_split, pred = sess.run([iis, isms, output_pred])

	# Keep only the first example of the tower's batch.
	img_split = np.squeeze(img_split)[0]
	seg_split = np.squeeze(seg_split)[0]
	pred = np.squeeze(pred)[0]

	# Annotations/predictions are at 1/4 resolution (128 for train_size 512);
	# shrink the input image to match so hstack lines up.
	# NOTE(review): 128x128 is hard-coded and only matches train_size 512.
	img_split = cv2.resize(img_split, (128, 128))

	# Colorize ground truth and prediction via the pickled color map
	# (keyed by the string form of the label id).
	color_seg = np.zeros((seg_split.shape[0], seg_split.shape[1], 3))
	for i in range(seg_split.shape[0]):
		for j in range(seg_split.shape[1]):
			color_seg[i, j, :] = color_map[str(seg_split[i][j])]

	color_pred = np.zeros((pred.shape[0], pred.shape[1], 3))
	for i in range(pred.shape[0]):
		for j in range(pred.shape[1]):
			color_pred[i, j, :] = color_map[str(class_labels[pred[i][j]])]

	write_img = np.hstack((img_split, color_seg, color_pred))
	summary_str = sess.run(log_image, feed_dict={log_image_name: '%06d' % step,
	                                             log_image_data: write_img})
	summary_writer.add_summary(summary_str, global_step=step)


def main(argv=None):
	"""Multi-GPU training entry point.

	Builds one model tower per GPU in FLAGS.gpu_list, averages gradients
	across towers, and trains with Adam. Periodically writes checkpoints,
	scalar summaries, and image summaries.
	"""
	gpus = range(len(FLAGS.gpu_list.split(',')))
	pascal_voc_lut = pascal_segmentation_lut()
	# Materialize as a list: it is indexed when colorizing predictions, and
	# dict.keys() is not indexable on Python 3.
	class_labels = list(pascal_voc_lut.keys())
	print(class_labels)
	with open('data/color_map', 'rb') as f:
		color_map = pickle.load(f)

	os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
	now = datetime.datetime.now()
	StyleTime = now.strftime("%Y-%m-%d-%H-%M-%S")
	os.makedirs(FLAGS.logs_path + StyleTime)
	if not os.path.exists(FLAGS.checkpoint_path):
		os.makedirs(FLAGS.checkpoint_path)
	elif not FLAGS.restore:
		# Fresh run: wipe stale checkpoints before training.
		shutil.rmtree(FLAGS.checkpoint_path)
		os.makedirs(FLAGS.checkpoint_path)

	# ---- Input pipeline --------------------------------------------------
	filename_queue = tf.train.string_input_producer([FLAGS.training_data_path], num_epochs=1000)
	image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue)

	image_train_size = [FLAGS.train_size, FLAGS.train_size]
	# The network predicts at 1/4 of the input resolution.
	annotation_train_size = [FLAGS.train_size // 4, FLAGS.train_size // 4]

	# method=1 is nearest-neighbour, which keeps annotation label ids intact.
	resized_image = tf.image.resize_images(image, image_train_size, method=1)
	resized_annotation = tf.image.resize_images(annotation, annotation_train_size, method=1)
	resized_annotation = tf.squeeze(resized_annotation)

	image_batch, annotation_batch = tf.train.shuffle_batch([resized_image, resized_annotation],
	                                                       batch_size=FLAGS.batch_size * len(gpus), capacity=1000,
	                                                       num_threads=4,
	                                                       min_after_dequeue=500)

	# One equal slice of the batch per GPU.
	input_images_split = tf.split(image_batch, len(gpus))
	input_segs_split = tf.split(annotation_batch, len(gpus))

	# ---- Optimizer and towers -------------------------------------------
	learning_rate = tf.Variable(FLAGS.learning_rate, trainable=False)
	global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
	tf.summary.scalar('learning_rate', learning_rate)
	opt = tf.train.AdamOptimizer(learning_rate)

	tower_grads = []
	reuse_variables = None
	for i, gpu_id in enumerate(gpus):
		with tf.device('/gpu:%d' % gpu_id):
			with tf.name_scope('model_%d' % gpu_id) as scope:
				iis = input_images_split[i]
				isms = input_segs_split[i]
				total_loss, model_loss, output_pred, output_logits = tower_loss(iis, isms, class_labels, reuse_variables)
				batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
				# Every tower after the first reuses the first tower's variables.
				reuse_variables = True
				grads = opt.compute_gradients(total_loss)
				tower_grads.append(grads)

	grads = average_gradients(tower_grads)
	apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

	summary_op = tf.summary.merge_all()
	log_image, log_image_data, log_image_name = build_image_summary()
	# Keep exponential moving averages of the trainable variables.
	variable_averages = tf.train.ExponentialMovingAverage(
		FLAGS.moving_average_decay, global_step)
	variables_averages_op = variable_averages.apply(tf.trainable_variables())
	# NOTE(review): batch_norm_updates_op is taken from the last tower only —
	# confirm that is intended (towers share variables, so it may suffice).
	with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
		train_op = tf.no_op(name='train_op')

	saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
	summary_writer = tf.summary.FileWriter(FLAGS.logs_path + StyleTime, tf.get_default_graph())

	if FLAGS.pretrained_model_path is not None:
		# Backbone weights; ignore_missing_vars tolerates head variables that
		# are absent from the ResNet checkpoint.
		variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
		                                                     slim.get_trainable_variables(),
		                                                     ignore_missing_vars=True)

	global_vars_init_op = tf.global_variables_initializer()
	local_vars_init_op = tf.local_variables_initializer()
	init = tf.group(local_vars_init_op, global_vars_init_op)

	gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
	config = tf.ConfigProto(allow_soft_placement=True,
	                        log_device_placement=False,
	                        gpu_options=gpu_options)
	config.gpu_options.allow_growth = True

	# ---- Training loop ---------------------------------------------------
	with tf.Session(config=config) as sess:
		restore_step = 0
		sess.run(init)
		if FLAGS.restore:
			print('continue training from previous checkpoint')
			ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
			# Checkpoint names look like SESNet_step_<N>.ckpt; recover <N>.
			restore_step = int(ckpt.split('.')[0].split('_')[-1])
			saver.restore(sess, ckpt)
		elif FLAGS.pretrained_model_path is not None:
			variable_restore_op(sess)

		start = time.time()
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(coord=coord)
		try:
			while not coord.should_stop():
				for step in range(restore_step, FLAGS.max_steps):
					# Step-decay the learning rate.
					if step != 0 and step % FLAGS.decay_steps == 0:
						sess.run(tf.assign(learning_rate, learning_rate.eval() * FLAGS.decay_rate))

					ml, tl, _ = sess.run([model_loss, total_loss, train_op])
					if np.isnan(tl):
						print('Loss diverged, stop training')
						break
					if step % 10 == 0:
						avg_time_per_step = (time.time() - start) / 10
						start = time.time()
						# Bug fix: the original called .format() on the return
						# value of print(), which crashes on Python 3.
						print('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.3f} seconds/step, lr: {:.7f}'.format(
							step, ml, tl, avg_time_per_step, learning_rate.eval()))

					if (step + 1) % FLAGS.save_checkpoint_steps == 0:
						filename = ('SESNet' + '_step_{:d}'.format(step + 1) + '.ckpt')
						filename = os.path.join(FLAGS.checkpoint_path, filename)
						saver.save(sess, filename)
						print('Write model to: {:s}'.format(filename))

					if step % FLAGS.save_summary_steps == 0:
						_, tl, summary_str = sess.run([train_op, total_loss, summary_op])
						summary_writer.add_summary(summary_str, global_step=step)

					if step % FLAGS.save_image_steps == 0:
						_write_image_summary(sess, step, iis, isms, output_pred, color_map,
						                     class_labels, log_image, log_image_data,
						                     log_image_name, summary_writer)
				# Bug fix: training is over once the for-loop exits (max_steps
				# reached or loss diverged); the original fell back into the
				# while-loop and restarted training from restore_step.
				break
		except tf.errors.OutOfRangeError:
			print('finish')
		finally:
			coord.request_stop()
		coord.join(threads)
+
+
# Script entry point: tf.app.run() parses the flags and invokes main().
if __name__ == '__main__':
	tf.app.run()