# ADDPG/main.py — asynchronous DDPG training entry point
1
import argparse
import multiprocessing
import os
import sys
import threading
import time

from model import *
7
worker_threads = []
8
9
def main():
    """Parse CLI options and launch asynchronous ADDPG training.

    Builds one global actor network plus ``--num_workers`` Worker instances
    sharing a single TF session, then runs each worker's ``work`` loop in its
    own daemon thread and blocks on the coordinator until they finish.

    Command-line flags:
        --load_model   restore the latest checkpoint from ./models
        --num_workers  number of asynchronous workers (default 3)
        --visualize    render the environment while training
    """
    parser = argparse.ArgumentParser(description='Train or test neural net motor controller')
    parser.add_argument('--load_model', dest='load_model', action='store_true', default=False)
    parser.add_argument('--num_workers', dest='num_workers', action='store', default=3, type=int)
    parser.add_argument('--visualize', dest='vis', action='store_true', default=False)
    args = parser.parse_args()

    load_model = args.load_model
    num_workers = args.num_workers
    vis = args.vis
    # NOTE(review): originally `not load_model` — kept hard-coded True so a
    # restored model continues training rather than running in test mode.
    training = True
    model_path = './models'

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # hyperparameters
    explore = 1000       # exploration annealing horizon (consumed by Worker)
    batch_size = 32
    gamma = 0.995        # discount factor
    n_step = 3           # n-step return length

    tf.reset_default_graph()

    with tf.Session() as sess:
        with tf.device("/cpu:0"):
            global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
            # 41+14+3 state dims / 18 action dims — presumably the osim-rl
            # observation/muscle layout; TODO confirm against model.py.
            global_actor_network = ActorNetwork(sess, 41 + 14 + 3, 18, 'global' + '/actor')
            workers = []
            for i in range(num_workers):
                worker = Worker(sess, i, model_path, global_episodes, explore, training, vis,
                                batch_size, gamma, n_step, global_actor_network.net)
                # BUG FIX: append was outside the loop, so only the LAST
                # worker was ever registered, started, and joined.
                workers.append(worker)
            saver = tf.train.Saver()

            coord = tf.train.Coordinator()
            if load_model:
                print('Loading Model...')
                ckpt = tf.train.get_checkpoint_state(model_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading Model succeeded...')
            else:
                sess.run(tf.global_variables_initializer())

            # This is where the asynchronous magic happens: run each worker's
            # work loop in its own daemon thread.
            for worker in workers:
                # BUG FIX: bind `worker` as a default argument — a plain
                # `lambda: worker.work(...)` late-binds and every thread could
                # end up running the last worker.
                t = threading.Thread(target=lambda w=worker: w.work(coord, saver))
                t.daemon = True
                t.start()
                worker_threads.append(t)
                # BUG FIX: `sleep` was an undefined name (NameError);
                # stagger thread start-up via time.sleep.
                time.sleep(0.05)
            coord.join(worker_threads)
66
        
67
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: ask every spawned worker thread to stop cooperatively by
        # setting a flag each work loop is expected to poll.
        print("Ctrl-c received! Sending kill to threads...")
        for worker_thread in worker_threads:
            worker_thread.kill_received = True