|
a |
|
b/ADDPG/main.py |
|
|
1 |
from model import * |
|
|
2 |
import sys |
|
|
3 |
import os |
|
|
4 |
import multiprocessing |
|
|
5 |
import threading |
|
|
6 |
import argparse |
|
|
7 |
# Module-level registry of the spawned worker Threads. Shared with the
# KeyboardInterrupt handler in the __main__ guard, which sets
# `kill_received` on each thread to signal shutdown.
worker_threads = []
|
|
8 |
|
|
|
9 |
def main():
    """Entry point: parse CLI flags, build the shared global actor network,
    spawn one worker thread per requested worker, and run them under a
    tf.train.Coordinator until all finish.

    Flags:
        --load_model   restore the latest checkpoint from ./models
        --num_workers  number of asynchronous workers (default 3)
        --visualize    enable environment visualization in the workers
    """
    parser = argparse.ArgumentParser(description='Train or test neural net motor controller')
    parser.add_argument('--load_model', dest='load_model', action='store_true', default=False)
    parser.add_argument('--num_workers', dest='num_workers', action='store', default=3, type=int)
    parser.add_argument('--visualize', dest='vis', action='store_true', default=False)
    args = parser.parse_args()

    load_model = args.load_model
    num_workers = args.num_workers
    vis = args.vis
    # Training is always enabled; it was previously tied to `not load_model`.
    training = True
    model_path = './models'

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Hyperparameters (interpreted by Worker — see model.py).
    explore = 1000       # exploration schedule length
    batch_size = 32
    gamma = 0.995        # discount factor
    n_step = 3           # n-step return length

    tf.reset_default_graph()

    with tf.Session() as sess:
        with tf.device("/cpu:0"):
            global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
            # Observation dim 41+14+3 and action dim 18 — presumably the
            # osim-rl "Learning to Run" environment; confirm against model.py.
            global_actor_network = ActorNetwork(sess, 41 + 14 + 3, 18, 'global' + '/actor')
            workers = []
            # One worker per requested thread, all sharing the global actor net.
            for i in range(num_workers):
                worker = Worker(sess, i, model_path, global_episodes, explore, training, vis,
                                batch_size, gamma, n_step, global_actor_network.net)
                workers.append(worker)
            saver = tf.train.Saver()

        coord = tf.train.Coordinator()
        if load_model:
            print('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(model_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading Model succeeded...')
        else:
            sess.run(tf.global_variables_initializer())

        # This is where the asynchronous magic happens:
        # start the "work" loop of each worker in its own daemon thread.
        for worker in workers:
            # Bind `worker` as a default argument. A plain `lambda:` closure is
            # late-binding — it would look up `worker` when the thread actually
            # runs, so threads starting after the loop advanced could all run
            # the same (last) worker.
            worker_work = lambda worker=worker: worker.work(coord, saver)
            t = threading.Thread(target=worker_work)
            t.daemon = True
            t.start()
            worker_threads.append(t)
            sleep(0.05)  # stagger startup so workers don't initialize simultaneously
        coord.join(worker_threads)
|
|
66 |
|
|
|
67 |
if __name__ == "__main__":
    # Run to completion; on Ctrl-C, flag every worker thread so each one
    # can observe `kill_received` and shut itself down.
    try:
        main()
    except KeyboardInterrupt:
        print("Ctrl-c received! Sending kill to threads...")
        for worker_thread in worker_threads:
            worker_thread.kill_received = True