--- a +++ b/A3C/main.py @@ -0,0 +1,92 @@ +from model import * +import argparse +import sys +import os +worker_threads = [] + +def main(): + + parser = argparse.ArgumentParser(description='Train or test neural net motor controller') + parser.add_argument('--load_model', dest='load_model', action='store_true', default=False) + parser.add_argument('--num_workers', dest='num_workers',action='store',default=1,type=int) + args = parser.parse_args() + + max_episode_length = 1000 + gamma = .995 # discount rate for advantage estimation and reward discounting + s_size = 41 + a_size = 18 # Agent can move Left, Right, or Straight + model_path = './models' + load_model = args.load_model + noisy=False + num_workers = args.num_workers + print(" num_workers = %d" % num_workers) + print(" noisy_net_enabled = %s" % str(noisy)) + print(" load_model = %s" % str(args.load_model)) + + tf.reset_default_graph() + + if not os.path.exists(model_path): + os.makedirs(model_path) + + + with tf.device("/cpu:0"): + global_episodes = tf.Variable(0,dtype=tf.int32,name='global_episodes',trainable=False) + trainer = tf.train.AdamOptimizer(learning_rate=1e-4) + master_network = AC_Network(s_size,a_size,'global',None,noisy) # Generate global network + num_cpu = multiprocessing.cpu_count() # Set workers ot number of available CPU threads + workers = [] + # Create worker classes + for i in range(args.num_workers): + worker = Worker(i,s_size,a_size,trainer,model_path,global_episodes,noisy,is_training= True) + workers.append(worker) + + saver = tf.train.Saver() + + '''networks = ['global'] + ['worker_'+i for i in str(range(num_workers))] + print(networks)''' + #key = print_tensors_in_checkpoint_file('./tmp/checkpoints/mobilenet_v1_0.50_160.ckpt', tensor_name='',all_tensors=True) + #print(key) + + with tf.Session() as sess: + coord = tf.train.Coordinator() + if load_model == True: + print ('Loading Model...') + ckpt = tf.train.get_checkpoint_state(model_path) + saver.restore(sess,ckpt.model_checkpoint_path) + print('loading Model succeeded') + else: + ''' + dict = {} + value = slim.get_model_variables('global'+'/MobilenetV1') + for variable in value: + name = variable.name.replace('global'+'/','').split(':')[0] + #print(name) + if name in key: + dict[name] = variable + #print(dict) + #print(dict) + init_fn = slim.assign_from_checkpoint_fn( + os.path.join(checkpoints_dir, 'mobilenet_v1_0.50_160.ckpt'), + dict) + init_fn(sess)''' + sess.run(tf.global_variables_initializer()) + + # This is where the asynchronous magic happens. + # Start the "work" process for each worker in a separate thread. + + for worker in workers: + worker_work = lambda: worker.work(max_episode_length,gamma,sess,coord,saver) + #worker.start(setting=0,vis=True) + t = threading.Thread(target=(worker_work)) + t.daemon = True + t.start() + worker_threads.append(t) + coord.join(worker_threads) + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("Ctrl-c received! Sending kill to threads...") + for t in worker_threads: + t.kill_received = True