Diff of /A3C/main.py [000000] .. [687a25]


from model import *
import argparse
import sys
import os
# Used directly in this file; not guaranteed to be re-exported by model's wildcard import.
import tensorflow as tf
import threading
import multiprocessing

worker_threads = []

def main():
    parser = argparse.ArgumentParser(description='Train or test neural net motor controller')
    parser.add_argument('--load_model', dest='load_model', action='store_true', default=False)
    parser.add_argument('--num_workers', dest='num_workers', action='store', default=1, type=int)
    args = parser.parse_args()

    max_episode_length = 1000
    gamma = .995  # discount rate for advantage estimation and reward discounting
    s_size = 41   # size of the state vector fed to the network
    a_size = 18   # size of the action space
    model_path = './models'
    load_model = args.load_model
    noisy = False
    num_workers = args.num_workers
    print(" num_workers = %d" % num_workers)
    print(" noisy_net_enabled = %s" % str(noisy))
    print(" load_model = %s" % str(args.load_model))

    tf.reset_default_graph()

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    with tf.device("/cpu:0"):
        global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
        trainer = tf.train.AdamOptimizer(learning_rate=1e-4)
        master_network = AC_Network(s_size, a_size, 'global', None, noisy)  # Generate global network
        num_cpu = multiprocessing.cpu_count()  # number of available CPU threads (informational; worker count comes from --num_workers)
        workers = []
        # Create worker classes
        for i in range(args.num_workers):
            worker = Worker(i, s_size, a_size, trainer, model_path, global_episodes, noisy, is_training=True)
            workers.append(worker)

        saver = tf.train.Saver()

    '''networks = ['global'] + ['worker_'+i for i in str(range(num_workers))]
    print(networks)'''
    #key = print_tensors_in_checkpoint_file('./tmp/checkpoints/mobilenet_v1_0.50_160.ckpt', tensor_name='',all_tensors=True)
    #print(key)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        if load_model:
            print('Loading model...')
            ckpt = tf.train.get_checkpoint_state(model_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading model succeeded')
        else:
            '''
            dict = {}
            value = slim.get_model_variables('global'+'/MobilenetV1')
            for variable in value:
                name = variable.name.replace('global'+'/','').split(':')[0]
                    #print(name)
                if name in key:
                    dict[name] = variable
                #print(dict)
            init_fn = slim.assign_from_checkpoint_fn(
                                os.path.join(checkpoints_dir, 'mobilenet_v1_0.50_160.ckpt'),
                                dict)
            init_fn(sess)'''
            sess.run(tf.global_variables_initializer())

        # This is where the asynchronous magic happens.
        # Start the "work" process for each worker in a separate thread.
        for worker in workers:
            # Bind the current worker as a default argument so each thread gets its
            # own worker rather than the loop variable's final value.
            worker_work = lambda worker=worker: worker.work(max_episode_length, gamma, sess, coord, saver)
            #worker.start(setting=0,vis=True)
            t = threading.Thread(target=worker_work)
            t.daemon = True
            t.start()
            worker_threads.append(t)
        coord.join(worker_threads)

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("Ctrl-c received! Sending kill to threads...")
        for t in worker_threads:
            t.kill_received = True
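
For reference, a typical invocation of this script (assuming model.py provides the AC_Network and Worker classes imported above, and that TensorFlow 1.x is installed) would look like:

    python main.py --num_workers 4
    python main.py --num_workers 4 --load_model   # resume from the latest checkpoint in ./models

The worker count here is only an example; --num_workers defaults to 1 and --load_model restores whatever checkpoint state exists under model_path.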