b/simulate.py
import numpy as np
import torch
from SAC.sac import SAC_Agent
from SAC.replay_memory import PolicyReplayMemory
from SAC.RL_Framework_Mujoco import Muscle_Env
from SAC import sensory_feedback_specs, kinematics_preprocessing_specs, perturbation_specs
import pickle
import os
from numpy.core.records import fromarrays
from scipy.io import savemat


#Set the current working directory
os.chdir(os.getcwd())


class Simulate():

    def __init__(self, env: Muscle_Env, args):

        """Train a soft actor-critic agent to control a musculoskeletal model to follow a kinematic trajectory.

        Parameters
        ----------
        env: Muscle_Env
            environment (musculoskeletal model) class to train on
        model: str
            specify whether to use an rnn or a gated rnn (gru)
        gamma: float
            discount parameter in the sac algorithm
        tau: float
            tau parameter for soft updates of the critic and critic target networks
        lr: float
            learning rate of the critic and actor networks
        alpha: float
            entropy term used in the sac policy update
        automatic_entropy_tuning: bool
            whether to use automatic entropy tuning during training
        seed: int
            seed for the environment and networks
        policy_batch_size: int
            number of episodes used in a batch during training
        hidden_size: int
            number of hidden neurons in the actor and critic
        policy_replay_size: int
            number of episodes to store in the replay buffer
        multi_policy_loss: bool
            use additional policy losses during updates (reduce norm of weights)
            ONLY USE WITH RNN, NOT IMPLEMENTED WITH GATING
        batch_iters: int
            number of experience replay rounds (not steps!) to perform between each update
        cuda: bool
            use a cuda gpu
        visualize: bool
            visualize the model during simulation
        model_save_name: str
            name of the model to save
        episodes: int
            total number of episodes to run
        save_iter: int
            number of iterations between model checkpoints
        checkpoint_path: str
            path for saving and loading the model
        muscle_path: str
            path to the musculoskeletal model
        muscle_params_path: str
            path to the musculoskeletal model parameters
        """

        ### TRAINING VARIABLES ###
        self.episodes = args.total_episodes
        self.hidden_size = args.hidden_size
        self.policy_batch_size = args.policy_batch_size
        self.visualize = args.visualize
        self.root_dir = args.root_dir
        self.checkpoint_file = args.checkpoint_file
        self.checkpoint_folder = args.checkpoint_folder
        self.statistics_folder = args.statistics_folder
        self.batch_iters = args.batch_iters
        self.save_iter = args.save_iter
        self.mode_to_sim = args.mode
        self.condition_selection_strategy = args.condition_selection_strategy
        self.load_saved_nets_for_training = args.load_saved_nets_for_training
        self.verbose_training = args.verbose_training

        ### ENSURE SAVE PATHS ARE SPECIFIED AS STRINGS ###
        assert isinstance(self.root_dir, str)
        assert isinstance(self.checkpoint_folder, str)
        assert isinstance(self.checkpoint_file, str)

        ### SEED ###
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        ### LOAD CUSTOM GYM ENVIRONMENT ###
        if self.mode_to_sim in ["musculo_properties"]:
            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets_pert.xml', 1, args)
        else:
            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets.xml', 1, args)

        #Observation: environment observation + visual velocity inputs (3 components each) + the task/condition scalar
        self.observation_shape = self.env.observation_space.shape[0] + len(self.env.sfs_visual_velocity)*3 + 1

        ### SAC AGENT ###
        self.agent = SAC_Agent(self.observation_shape,
                               self.env.action_space,
                               args.hidden_size,
                               args.lr,
                               args.gamma,
                               args.tau,
                               args.alpha,
                               args.automatic_entropy_tuning,
                               args.model,
                               args.multi_policy_loss,
                               args.alpha_usim,
                               args.beta_usim,
                               args.gamma_usim,
                               args.zeta_nusim,
                               args.cuda)

        ### REPLAY MEMORY ###
        self.policy_memory = PolicyReplayMemory(args.policy_replay_size, args.seed)

    def test(self, save_name):

        """ Use a saved model to generate kinematic trajectories for all conditions.

        Saves the following to the `save_name` directory:
        --------
        test_data.pkl: dict
            - EMG (actions), RNN activity, RNN inputs, and the kinematics of the
              tracked musculoskeletal bodies and their targets, per condition
        Data_jpca.mat, n_fixedsteps_jpca.mat, condition_tpoints_jpca.mat
            - RNN activity and timing information formatted for jPCA analysis
        """

        #Update the environment kinematics to both the training and testing conditions
        self.env.update_kinematics_for_test()

        ### LOAD SAVED MODEL ###
        self.load_saved_nets_from_checkpoint(load_best=True)

        #Set the recurrent connections to zero if the mode is SFE
        if self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
            self.agent.actor.rnn.weight_hh_l0 = torch.nn.Parameter(self.agent.actor.rnn.weight_hh_l0 * 0)

        ### TESTING DATA ###
        Test_Data = {
            "emg": {},
            "rnn_activity": {},
            "rnn_input": {},
            "rnn_input_fp": {},
            "kinematics_mbodies": {},
            "kinematics_mtargets": {}
        }

        ### Data for jPCA ###
        activity_jpca = []
        times_jpca = []
        n_fixedsteps_jpca = []
        condition_tpoints_jpca = []

        #Save the following after testing: EMG, RNN_activity, RNN_input, kin_musculo_bodies, kin_musculo_targets
        emg = {}
        rnn_activity = {}
        rnn_input = {}
        kin_mb = {}
        kin_mt = {}
        rnn_input_fp = {}

        for i_cond_sim in range(self.env.n_exp_conds):

            ### TRACKING VARIABLES ###
            episode_reward = 0
            episode_steps = 0

            emg_cond = []
            kin_mb_cond = []
            kin_mt_cond = []
            rnn_activity_cond = []
            rnn_input_cond = []
            rnn_input_fp_cond = []

            done = False

            ### GET INITIAL STATE + RESET MODEL BY POSE ###
            cond_to_select = i_cond_sim % self.env.n_exp_conds
            state = self.env.reset(cond_to_select)

            #Append the high-level task scalar (zeroed out when it is eliminated in SFE mode)
            if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
                state = [*state, 0]
            else:
                state = [*state, self.env.condition_scalar]

            # Num_layers specified in the policy model
            h_prev = torch.zeros(size=(1, 1, self.hidden_size))

            ### STEPS PER EPISODE ###
            for timestep in range(self.env.timestep_limit):

                ### SELECT ACTION ###
                with torch.no_grad():

                    if self.mode_to_sim in ["neural_pert"]:
                        state = torch.FloatTensor(state).to(self.agent.device).unsqueeze(0).unsqueeze(0)
                        h_prev = h_prev.to(self.agent.device)
                        neural_pert = perturbation_specs.neural_pert[timestep % perturbation_specs.neural_pert.shape[0], :]
                        neural_pert = torch.FloatTensor(neural_pert).to(self.agent.device).unsqueeze(0).unsqueeze(0)
                        action, h_current, rnn_act, rnn_in = self.agent.actor.forward_for_neural_pert(state, h_prev, neural_pert)

                    elif self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
                        h_prev = h_prev*0
                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)
                    else:
                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)

                emg_cond.append(action)                     #[n_muscles, ]
                rnn_activity_cond.append(rnn_act[0, :])     #[1, n_hidden_units] --> [n_hidden_units, ]
                rnn_input_cond.append(state)                #[n_inputs, ]
                rnn_input_fp_cond.append(rnn_in[0, 0, :])   #[1, 1, n_hidden_units] --> [n_hidden_units, ]

                ### TRACKING REWARD + EXPERIENCE TUPLE ###
                next_state, reward, done, _ = self.env.step(action)

                if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
                    next_state = [*next_state, 0]
                else:
                    next_state = [*next_state, self.env.condition_scalar]

                episode_reward += reward

                #Now append the kinematics of the musculo body and the corresponding target
                kin_mb_t = []
                kin_mt_t = []
                for musculo_body in kinematics_preprocessing_specs.musculo_tracking:
                    kin_mb_t.append(self.env.sim.data.get_body_xpos(musculo_body[0]).copy())   #[3, ]
                    kin_mt_t.append(self.env.sim.data.get_body_xpos(musculo_body[1]).copy())   #[3, ]

                kin_mb_cond.append(kin_mb_t)   # kin_mb_t : [n_targets, 3]
                kin_mt_cond.append(kin_mt_t)   # kin_mt_t : [n_targets, 3]

                ### VISUALIZE MODEL ###
                if self.visualize == True:
                    self.env.render()

                state = next_state
                h_prev = h_current

            #Append the testing data
            emg[i_cond_sim] = np.array(emg_cond)                       # [timepoints, muscles]
            rnn_activity[i_cond_sim] = np.array(rnn_activity_cond)     # [timepoints, n_hidden_units]
            rnn_input[i_cond_sim] = np.array(rnn_input_cond)           # [timepoints, n_inputs]
            rnn_input_fp[i_cond_sim] = np.array(rnn_input_fp_cond)     # [timepoints, n_hidden_units]
            kin_mb[i_cond_sim] = np.array(kin_mb_cond).transpose(1, 0, 2)   # [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
            kin_mt[i_cond_sim] = np.array(kin_mt_cond).transpose(1, 0, 2)   # [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]

            #Append the jPCA data
            activity_jpca.append(dict(A=rnn_activity_cond))
            times_jpca.append(dict(times=np.arange(timestep)))   #the timestep is assumed to be 1ms
            condition_tpoints_jpca.append(self.env.kin_to_sim[self.env.current_cond_to_sim].shape[-1])
            n_fixedsteps_jpca.append(self.env.n_fixedsteps)

        ### SAVE TESTING STATS ###
        Test_Data["emg"] = emg
        Test_Data["rnn_activity"] = rnn_activity
        Test_Data["rnn_input"] = rnn_input
        Test_Data["rnn_input_fp"] = rnn_input_fp
        Test_Data["kinematics_mbodies"] = kin_mb
        Test_Data["kinematics_mtargets"] = kin_mt

        ### Save the jPCA data ###
        Data_jpca = fromarrays([activity_jpca, times_jpca], names=['A', 'times'])

        #Save the test data
        with open(save_name + '/test_data.pkl', 'wb') as f:
            pickle.dump(Test_Data, f)

        #Save the jPCA data
        savemat(save_name + '/Data_jpca.mat', {'Data': Data_jpca})
        savemat(save_name + '/n_fixedsteps_jpca.mat', {'n_fsteps': n_fixedsteps_jpca})
        savemat(save_name + '/condition_tpoints_jpca.mat', {'cond_tpoints': condition_tpoints_jpca})
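
        #Illustrative note (not executed): the saved results can be reloaded for analysis, e.g.
        #   with open(save_name + '/test_data.pkl', 'rb') as f:
        #       test_data = pickle.load(f)
        #   test_data["emg"][0]            # EMG for condition 0, [timepoints, n_muscles]
        #   test_data["rnn_activity"][0]   # RNN activity for condition 0, [timepoints, n_hidden_units]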

    def train(self):

        """ Train an RNN-based SAC agent to follow a kinematic trajectory. """

        #Load the saved networks from the last training run
        if self.load_saved_nets_for_training:
            self.load_saved_nets_from_checkpoint(load_best=False)

        ### TRAINING DATA DICTIONARY ###
        Statistics = {
            "rewards": [],
            "steps": [],
            "policy_loss": [],
            "critic_loss": []
        }

        highest_reward = -float("inf")   # used for storing the highest reward throughout training

        #Initialization for the reward-based condition selection strategy
        cond_train_count = np.ones((self.env.n_exp_conds,))
        cond_avg_reward = np.zeros((self.env.n_exp_conds,))
        cond_cum_reward = np.zeros((self.env.n_exp_conds,))
        cond_cum_count = np.zeros((self.env.n_exp_conds,))

        ### BEGIN TRAINING ###
        for episode in range(self.episodes):

            ### Gather Episode Data Variables ###
            episode_reward = 0           # reward for single episode
            episode_steps = 0            # steps completed for single episode
            policy_loss_tracker = []     # stores policy loss throughout episode
            critic1_loss_tracker = []    # stores critic loss throughout episode
            done = False                 # determines if episode is terminated

            ### GET INITIAL STATE + RESET MODEL BY POSE ###
            if self.condition_selection_strategy != "reward":
                cond_to_select = episode % self.env.n_exp_conds
                state = self.env.reset(cond_to_select)

            else:
                #Pick the first condition that still has training budget left in this round
                cond_indx = np.nonzero(cond_train_count > 0)[0][0]
                cond_train_count[cond_indx] = cond_train_count[cond_indx] - 1
                state = self.env.reset(cond_indx)

            #Append the high-level task scalar signal
            state = [*state, self.env.condition_scalar]

            ep_trajectory = []   # used to store (s_t, a_t, r_t, s_t+1) tuples for replay storage

            h_prev = torch.zeros(size=(1, 1, self.hidden_size))   # num_layers specified in the policy model

            ### LOOP THROUGH EPISODE TIMESTEPS ###
            while not done:

                ### SELECT ACTION ###
                with torch.no_grad():
                    action, h_current, _, _ = self.agent.select_action(state, h_prev, evaluate=False)

                #Query the neural activity index from the simulator
                na_idx = self.env.coord_idx

                ### UPDATE MODEL PARAMETERS ###
                if len(self.policy_memory.buffer) > self.policy_batch_size:
                    for _ in range(self.batch_iters):
                        critic_1_loss, critic_2_loss, policy_loss = self.agent.update_parameters(self.policy_memory, self.policy_batch_size)

                        ### STORE LOSSES ###
                        policy_loss_tracker.append(policy_loss)
                        critic1_loss_tracker.append(critic_1_loss)

                ### SIMULATION ###
                reward = 0
                for _ in range(self.env.frame_repeat):
                    next_state, inter_reward, done, _ = self.env.step(action)
                    next_state = [*next_state, self.env.condition_scalar]

                    reward += inter_reward
                    episode_steps += 1

                    ### EARLY TERMINATION OF EPISODE ###
                    if done:
                        break

                episode_reward += reward

                ### VISUALIZE MODEL ###
                if self.visualize == True:
                    self.env.render()

                # mask is 1 when the episode ends only because the time limit was reached, so the
                # critic still bootstraps from the final state; it is 0 on a true terminal state
                mask = 1 if episode_steps == self.env._max_episode_steps else float(not done)

                ### STORE CURRENT TIMESTEP TUPLE ###
                if not self.env.nusim_data_exists:
                    self.env.neural_activity[na_idx] = 0

                ep_trajectory.append((state,
                                      action,
                                      reward,
                                      next_state,
                                      mask,
                                      h_current.squeeze(0).cpu().numpy(),
                                      self.env.neural_activity[na_idx],
                                      np.array([na_idx])))

                ### MOVE TO NEXT STATE ###
                state = next_state
                h_prev = h_current

            ### PUSH TO REPLAY ###
            self.policy_memory.push(ep_trajectory)

            if self.condition_selection_strategy == "reward":

                cond_cum_reward[cond_indx] = cond_cum_reward[cond_indx] + episode_reward
                cond_cum_count[cond_indx] = cond_cum_count[cond_indx] + 1

                #When all entries of cond_train_count reach zero, recompute the per-condition training counts from the average rewards
                if np.all((cond_train_count == 0)):
                    cond_avg_reward = cond_cum_reward / cond_cum_count
                    cond_train_count = np.ceil((np.max(cond_avg_reward)*np.ones((self.env.n_exp_conds,)))/cond_avg_reward)
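
                    # Worked example (illustrative numbers): if the average rewards per condition
                    # were [10., 5., 2.], the new counts would be ceil(10/[10, 5, 2]) = [1, 2, 5],
                    # so poorly performing conditions are selected more often in the next round.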

            ### TRACKING ###
            Statistics["rewards"].append(episode_reward)
            Statistics["steps"].append(episode_steps)
            Statistics["policy_loss"].append(np.mean(np.array(policy_loss_tracker)))
            Statistics["critic_loss"].append(np.mean(np.array(critic1_loss_tracker)))

            ### SAVE DATA TO FILE (in root project folder) ###
            if len(self.statistics_folder) != 0:
                np.save(self.statistics_folder + '/stats_rewards.npy', Statistics['rewards'])
                np.save(self.statistics_folder + '/stats_steps.npy', Statistics['steps'])
                np.save(self.statistics_folder + '/stats_policy_loss.npy', Statistics['policy_loss'])
                np.save(self.statistics_folder + '/stats_critic_loss.npy', Statistics['critic_loss'])

            ### SAVING STATE DICT OF TRAINING ###
            if len(self.checkpoint_folder) != 0 and len(self.checkpoint_file) != 0:
                if episode % self.save_iter == 0 and len(self.policy_memory.buffer) > self.policy_batch_size:

                    #Save the state dicts
                    torch.save({
                        'iteration': episode,
                        'agent_state_dict': self.agent.actor.state_dict(),
                        'critic_state_dict': self.agent.critic.state_dict(),
                        'critic_target_state_dict': self.agent.critic_target.state_dict(),
                        'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
                        'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
                    }, self.checkpoint_folder + f'/{self.checkpoint_file}.pth')

                    #Save the pickled rnn for fixed-point finder analysis
                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + '/actor_rnn_fpf.pth')

                if episode_reward > highest_reward:
                    torch.save({
                        'iteration': episode,
                        'agent_state_dict': self.agent.actor.state_dict(),
                        'critic_state_dict': self.agent.critic.state_dict(),
                        'critic_target_state_dict': self.agent.critic_target.state_dict(),
                        'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
                        'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
                    }, self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')

                    #Save the pickled rnn for fixed-point finder analysis
                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + '/actor_rnn_best_fpf.pth')

            if episode_reward > highest_reward:
                highest_reward = episode_reward

            ### PRINT TRAINING OUTPUT ###
            if self.verbose_training:
                print('-----------------------------------')
                print('highest reward: {} | reward: {} | timesteps completed: {}'.format(highest_reward, episode_reward, episode_steps))
                print('-----------------------------------\n')

    def load_saved_nets_from_checkpoint(self, load_best: bool):

        #Load the saved networks from the checkpoint file
        #Saved networks include the policy, critic, critic_target, policy optimizer and critic optimizer
        if load_best:
            checkpoint_path = self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth'
        else:
            checkpoint_path = self.checkpoint_folder + f'/{self.checkpoint_file}.pth'

        checkpoint = torch.load(checkpoint_path)

        #Load the policy network
        self.agent.actor.load_state_dict(checkpoint['agent_state_dict'])

        #Load the critic network
        self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])

        #Load the critic target network
        self.agent.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])

        #Load the policy optimizer
        self.agent.actor_optim.load_state_dict(checkpoint['agent_optimizer_state_dict'])

        #Load the critic optimizer
        self.agent.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
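
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed): Simulate is normally driven
# by the project's own runner script, which builds the `args` object from its
# argument parser; the names below are assumptions made for the example.
#
#   from SAC.RL_Framework_Mujoco import Muscle_Env
#   args = build_args()                 # hypothetical project argument parser
#   sim = Simulate(Muscle_Env, args)
#   sim.train()                         # or: sim.test(args.statistics_folder)
# ---------------------------------------------------------------------------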