Diff of /simulate.py [000000] .. [9f010e]

import numpy as np
import torch
from SAC.sac import SAC_Agent
from SAC.replay_memory import PolicyReplayMemory
from SAC.RL_Framework_Mujoco import Muscle_Env
from SAC import sensory_feedback_specs, kinematics_preprocessing_specs, perturbation_specs
import pickle
import os
from numpy.core.records import fromarrays
from scipy.io import savemat


# Set the current working directory
os.chdir(os.getcwd())

class Simulate():

    def __init__(self, env: Muscle_Env, args):

        """Train a soft actor-critic agent to control a musculoskeletal model to follow a kinematic trajectory.

        Parameters
        ----------
        env: Muscle_Env
            environment (musculoskeletal model) to train on
        model: str
            whether to use an rnn or a gated rnn (gru)
        gamma: float
            discount factor in the sac algorithm
        tau: float
            tau parameter for soft updates of the critic and critic target networks
        lr: float
            learning rate of the critic and actor networks
        alpha: float
            entropy term used in the sac policy update
        automatic_entropy_tuning: bool
            whether to use automatic entropy tuning during training
        seed: int
            seed for the environment and networks
        policy_batch_size: int
            number of episodes used in a batch during training
        hidden_size: int
            number of hidden neurons in the actor and critic
        policy_replay_size: int
            number of episodes to store in the replay memory
        multi_policy_loss: bool
            use additional policy losses during updates (reduce the norm of the weights);
            ONLY USE WITH RNN, NOT IMPLEMENTED WITH GATING
        batch_iters: int
            how many experience replay rounds (not steps!) to perform between each update
        cuda: bool
            use a cuda gpu
        visualize: bool
            visualize the model during simulation
        model_save_name: str
            name under which to save the model
        episodes: int
            total number of episodes to run
        save_iter: int
            number of iterations before saving the model
        checkpoint_path: str
            path for saving and loading the model
        muscle_path: str
            path to the musculoskeletal model
        muscle_params_path: str
            path to the musculoskeletal model parameters
        """

        ### TRAINING VARIABLES ###
        self.episodes = args.total_episodes
        self.hidden_size = args.hidden_size
        self.policy_batch_size = args.policy_batch_size
        self.visualize = args.visualize
        self.root_dir = args.root_dir
        self.checkpoint_file = args.checkpoint_file
        self.checkpoint_folder = args.checkpoint_folder
        self.statistics_folder = args.statistics_folder
        self.batch_iters = args.batch_iters
        self.save_iter = args.save_iter
        self.mode_to_sim = args.mode
        self.condition_selection_strategy = args.condition_selection_strategy
        self.load_saved_nets_for_training = args.load_saved_nets_for_training
        self.verbose_training = args.verbose_training

        ### ENSURE SAVE PATHS ARE VALID STRINGS ###
        assert isinstance(self.root_dir, str)
        assert isinstance(self.checkpoint_folder, str)
        assert isinstance(self.checkpoint_file, str)

        ### SEED ###
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        ### LOAD CUSTOM GYM ENVIRONMENT ###
        if self.mode_to_sim in ["musculo_properties"]:
            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets_pert.xml', 1, args)
        else:
            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets.xml', 1, args)

        # Observation = environment observation + 3 values per visual-velocity feedback entry + 1 task/condition scalar
        self.observation_shape = self.env.observation_space.shape[0] + len(self.env.sfs_visual_velocity)*3 + 1

        ### SAC AGENT ###
        self.agent = SAC_Agent(self.observation_shape,
                               self.env.action_space,
                               args.hidden_size,
                               args.lr,
                               args.gamma,
                               args.tau,
                               args.alpha,
                               args.automatic_entropy_tuning,
                               args.model,
                               args.multi_policy_loss,
                               args.alpha_usim,
                               args.beta_usim,
                               args.gamma_usim,
                               args.zeta_nusim,
                               args.cuda)

        ### REPLAY MEMORY ###
        self.policy_memory = PolicyReplayMemory(args.policy_replay_size, args.seed)

    def test(self, save_name):

        """ Use a saved model to generate kinematic trajectories for every experimental condition.

            Saves the following to `save_name`:
            -----------------------------------
                test_data.pkl: dict
                    - emg, rnn_activity, rnn_input, rnn_input_fp, kinematics_mbodies and
                      kinematics_mtargets, each keyed by condition index
                Data_jpca.mat, n_fixedsteps_jpca.mat, condition_tpoints_jpca.mat
                    - RNN activity and timing information formatted for jPCA analysis
        """

        # Update the environment kinematics to both the training and testing conditions
        self.env.update_kinematics_for_test()

        ### LOAD SAVED MODEL ###
        self.load_saved_nets_from_checkpoint(load_best=True)

        # Set the recurrent connections to zero if the mode is SFE
        if self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
            self.agent.actor.rnn.weight_hh_l0 = torch.nn.Parameter(self.agent.actor.rnn.weight_hh_l0 * 0)

        ### TESTING DATA ###
        Test_Data = {
            "emg": {},
            "rnn_activity": {},
            "rnn_input": {},
            "rnn_input_fp": {},
            "kinematics_mbodies": {},
            "kinematics_mtargets": {}
        }

        ### Data for jPCA ###
        activity_jpca = []
        times_jpca = []
        n_fixedsteps_jpca = []
        condition_tpoints_jpca = []

        # Save the following after testing: EMG, RNN activity, RNN input, kinematics of musculo bodies and targets
        emg = {}
        rnn_activity = {}
        rnn_input = {}
        kin_mb = {}
        kin_mt = {}
        rnn_input_fp = {}

        for i_cond_sim in range(self.env.n_exp_conds):

            ### TRACKING VARIABLES ###
            episode_reward = 0
            episode_steps = 0

            emg_cond = []
            kin_mb_cond = []
            kin_mt_cond = []
            rnn_activity_cond = []
            rnn_input_cond = []
            rnn_input_fp_cond = []

            done = False

            ### GET INITIAL STATE + RESET MODEL BY POSE ###
            cond_to_select = i_cond_sim % self.env.n_exp_conds
            state = self.env.reset(cond_to_select)

            if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
                state = [*state, 0]
            else:
                state = [*state, self.env.condition_scalar]

            # num_layers specified in the policy model
            h_prev = torch.zeros(size=(1, 1, self.hidden_size))

            ### STEPS PER EPISODE ###
            for timestep in range(self.env.timestep_limit):

                ### SELECT ACTION ###
                with torch.no_grad():

                    if self.mode_to_sim in ["neural_pert"]:
                        state = torch.FloatTensor(state).to(self.agent.device).unsqueeze(0).unsqueeze(0)
                        h_prev = h_prev.to(self.agent.device)
                        neural_pert = perturbation_specs.neural_pert[timestep % perturbation_specs.neural_pert.shape[0], :]
                        neural_pert = torch.FloatTensor(neural_pert).to(self.agent.device).unsqueeze(0).unsqueeze(0)
                        action, h_current, rnn_act, rnn_in = self.agent.actor.forward_for_neural_pert(state, h_prev, neural_pert)

                    elif self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
                        h_prev = h_prev * 0
                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)
                    else:
                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)

                    emg_cond.append(action)                    # [n_muscles, ]
                    rnn_activity_cond.append(rnn_act[0, :])    # [1, n_hidden_units] --> [n_hidden_units, ]
                    rnn_input_cond.append(state)               # [n_inputs, ]
                    rnn_input_fp_cond.append(rnn_in[0, 0, :])  # [1, 1, n_hidden_units] --> [n_hidden_units, ]

                ### STEP THE ENVIRONMENT + TRACK REWARD ###
                next_state, reward, done, _ = self.env.step(action)

                if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
                    next_state = [*next_state, 0]
                else:
                    next_state = [*next_state, self.env.condition_scalar]

                episode_reward += reward

                # Append the kinematics of each musculo body and its corresponding target
                kin_mb_t = []
                kin_mt_t = []
                for musculo_body in kinematics_preprocessing_specs.musculo_tracking:
                    kin_mb_t.append(self.env.sim.data.get_body_xpos(musculo_body[0]).copy())  # [3, ]
                    kin_mt_t.append(self.env.sim.data.get_body_xpos(musculo_body[1]).copy())  # [3, ]

                kin_mb_cond.append(kin_mb_t)   # kin_mb_t : [n_targets, 3]
                kin_mt_cond.append(kin_mt_t)   # kin_mt_t : [n_targets, 3]

                ### VISUALIZE MODEL ###
                if self.visualize:
                    self.env.render()

                state = next_state
                h_prev = h_current

            #Append the testing data
255
            emg[i_cond_sim] = np.array(emg_cond)  # [timepoints, muscles]
256
            rnn_activity[i_cond_sim] = np.array(rnn_activity_cond) # [timepoints, n_hidden_units]
257
            rnn_input[i_cond_sim] = np.array(rnn_input_cond)  #[timepoints, n_inputs]
258
            rnn_input_fp[i_cond_sim] = np.array(rnn_input_fp_cond)  #[timepoints, n_hidden_units]
259
            kin_mb[i_cond_sim] = np.array(kin_mb_cond).transpose(1, 0, 2)   # kin_mb_cond: [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
260
            kin_mt[i_cond_sim] = np.array(kin_mt_cond).transpose(1, 0, 2)   # kin_mt_cond: [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
261
262
            ##Append the jpca data
263
            activity_jpca.append(dict(A = rnn_activity_cond))
264
            times_jpca.append(dict(times = np.arange(timestep)))    #the timestep is assumed to be 1ms
265
            condition_tpoints_jpca.append(self.env.kin_to_sim[self.env.current_cond_to_sim].shape[-1])
266
            n_fixedsteps_jpca.append(self.env.n_fixedsteps)
267
268
        ### SAVE TESTING STATS ###
269
        Test_Data["emg"] = emg
270
        Test_Data["rnn_activity"] = rnn_activity
271
        Test_Data["rnn_input"] = rnn_input
272
        Test_Data["rnn_input_fp"] = rnn_input_fp
273
        Test_Data["kinematics_mbodies"] = kin_mb
274
        Test_Data["kinematics_mtargets"] = kin_mt
275
276
        ### Save the jPCA data
277
        Data_jpca = fromarrays([activity_jpca, times_jpca], names=['A', 'times'])
278
279
        #save test data
280
        with open(save_name + '/test_data.pkl', 'wb') as f:
281
            pickle.dump(Test_Data, f)
282
283
        #save jpca data
284
        savemat(save_name + '/Data_jpca.mat', {'Data' : Data_jpca})
285
        savemat(save_name + '/n_fixedsteps_jpca.mat', {'n_fsteps' : n_fixedsteps_jpca})
286
        savemat(save_name + '/condition_tpoints_jpca.mat', {'cond_tpoints': condition_tpoints_jpca})
287
        
288
289
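    # Illustrative (not part of the original file): the artifacts written by test() can be
    # reloaded for offline analysis roughly as follows; `save_dir` is whatever was passed
    # as `save_name`:
    #
    #     import pickle
    #     from scipy.io import loadmat
    #     with open(save_dir + '/test_data.pkl', 'rb') as f:
    #         test_data = pickle.load(f)          # keys: emg, rnn_activity, rnn_input, ...
    #     jpca_data = loadmat(save_dir + '/Data_jpca.mat')['Data']
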
    def train(self):

        """ Train an RNN-based SAC agent to follow a kinematic trajectory.
        """

        # Load the saved networks from the last training run
        if self.load_saved_nets_for_training:
            self.load_saved_nets_from_checkpoint(load_best=False)

        ### TRAINING DATA DICTIONARY ###
        Statistics = {
            "rewards": [],
            "steps": [],
            "policy_loss": [],
            "critic_loss": []
        }

        highest_reward = -float("inf")  # stores the highest episode reward seen throughout training

        # Initialization of the per-condition reward averaging
        cond_train_count = np.ones((self.env.n_exp_conds,))
        cond_avg_reward = np.zeros((self.env.n_exp_conds,))
        cond_cum_reward = np.zeros((self.env.n_exp_conds,))
        cond_cum_count = np.zeros((self.env.n_exp_conds,))

        ### BEGIN TRAINING ###
        for episode in range(self.episodes):

            ### Gather Episode Data Variables ###
            episode_reward = 0          # reward for a single episode
            episode_steps = 0           # steps completed for a single episode
            policy_loss_tracker = []    # stores the policy loss throughout the episode
            critic1_loss_tracker = []   # stores the critic loss throughout the episode
            done = False                # whether the episode has terminated

            ### GET INITIAL STATE + RESET MODEL BY POSE ###
            if self.condition_selection_strategy != "reward":
                cond_to_select = episode % self.env.n_exp_conds
                state = self.env.reset(cond_to_select)

            else:
                cond_indx = np.nonzero(cond_train_count > 0)[0][0]
                cond_train_count[cond_indx] = cond_train_count[cond_indx] - 1
                state = self.env.reset(cond_indx)

            # Append the high-level task scalar signal
            state = [*state, self.env.condition_scalar]

            # stores the (s_t, a_t, r_t, s_{t+1}, mask, h_t, neural activity, na_idx) tuples for replay storage
            ep_trajectory = []

            h_prev = torch.zeros(size=(1, 1, self.hidden_size))  # num_layers specified in the policy model

            ### LOOP THROUGH EPISODE TIMESTEPS ###
            while not done:

                ### SELECT ACTION ###
                with torch.no_grad():
                    action, h_current, _, _ = self.agent.select_action(state, h_prev, evaluate=False)

                    # Query the neural activity index from the simulator
                    na_idx = self.env.coord_idx

                ### UPDATE MODEL PARAMETERS ###
                if len(self.policy_memory.buffer) > self.policy_batch_size:
                    for _ in range(self.batch_iters):
                        critic_1_loss, critic_2_loss, policy_loss = self.agent.update_parameters(self.policy_memory, self.policy_batch_size)
                        ### STORE LOSSES ###
                        policy_loss_tracker.append(policy_loss)
                        critic1_loss_tracker.append(critic_1_loss)

                ### SIMULATION ###
                reward = 0
                for _ in range(self.env.frame_repeat):
                    next_state, inter_reward, done, _ = self.env.step(action)
                    next_state = [*next_state, self.env.condition_scalar]

                    reward += inter_reward
                    episode_steps += 1

                    ### EARLY TERMINATION OF EPISODE ###
                    if done:
                        break

                episode_reward += reward

                ### VISUALIZE MODEL ###
                if self.visualize:
                    self.env.render()

                # mask = 0 only on true termination; keep it at 1 when the episode ends by hitting the step limit
                mask = 1 if episode_steps == self.env._max_episode_steps else float(not done)

                ### STORE CURRENT TIMESTEP TUPLE ###
                if not self.env.nusim_data_exists:
                    self.env.neural_activity[na_idx] = 0

                ep_trajectory.append((state,
                                      action,
                                      reward,
                                      next_state,
                                      mask,
                                      h_current.squeeze(0).cpu().numpy(),
                                      self.env.neural_activity[na_idx],
                                      np.array([na_idx])))

                ### MOVE TO NEXT STATE ###
                state = next_state
                h_prev = h_current

            ### PUSH TO REPLAY ###
            self.policy_memory.push(ep_trajectory)

            if self.condition_selection_strategy == "reward":

                cond_cum_reward[cond_indx] = cond_cum_reward[cond_indx] + episode_reward
                cond_cum_count[cond_indx] = cond_cum_count[cond_indx] + 1

                # Once every condition's training budget reaches zero, recompute the budgets from the
                # average rewards, so conditions with lower average reward are selected more often
                if np.all((cond_train_count == 0)):
                    cond_avg_reward = cond_cum_reward / cond_cum_count
                    cond_train_count = np.ceil((np.max(cond_avg_reward)*np.ones((self.env.n_exp_conds,)))/cond_avg_reward)
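
                    # Worked example (illustrative numbers, not from the original code):
                    # if cond_avg_reward = [40., 20., 10.], then max(cond_avg_reward) = 40 and
                    # cond_train_count = ceil([40/40, 40/20, 40/10]) = [1., 2., 4.], so the
                    # lowest-reward condition is scheduled four times as often in the next round.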

            ### TRACKING ###
            Statistics["rewards"].append(episode_reward)
            Statistics["steps"].append(episode_steps)
            Statistics["policy_loss"].append(np.mean(np.array(policy_loss_tracker)))
            Statistics["critic_loss"].append(np.mean(np.array(critic1_loss_tracker)))

            ### SAVE DATA TO FILE (in root project folder) ###
            if len(self.statistics_folder) != 0:
                np.save(self.statistics_folder + '/stats_rewards.npy', Statistics['rewards'])
                np.save(self.statistics_folder + '/stats_steps.npy', Statistics['steps'])
                np.save(self.statistics_folder + '/stats_policy_loss.npy', Statistics['policy_loss'])
                np.save(self.statistics_folder + '/stats_critic_loss.npy', Statistics['critic_loss'])

            ### SAVING STATE DICTS OF TRAINING ###
            if len(self.checkpoint_folder) != 0 and len(self.checkpoint_file) != 0:
                if episode % self.save_iter == 0 and len(self.policy_memory.buffer) > self.policy_batch_size:

                    # Save the state dicts
                    torch.save({
                        'iteration': episode,
                        'agent_state_dict': self.agent.actor.state_dict(),
                        'critic_state_dict': self.agent.critic.state_dict(),
                        'critic_target_state_dict': self.agent.critic_target.state_dict(),
                        'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
                        'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
                    }, self.checkpoint_folder + f'/{self.checkpoint_file}.pth')

                    # Save the pickled model for fixed-point finder analysis
                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + '/actor_rnn_fpf.pth')

                if episode_reward > highest_reward:
                    torch.save({
                        'iteration': episode,
                        'agent_state_dict': self.agent.actor.state_dict(),
                        'critic_state_dict': self.agent.critic.state_dict(),
                        'critic_target_state_dict': self.agent.critic_target.state_dict(),
                        'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
                        'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
                    }, self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')

                    # Save the pickled model for fixed-point finder analysis
                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + '/actor_rnn_best_fpf.pth')

            if episode_reward > highest_reward:
                highest_reward = episode_reward

            ### PRINT TRAINING OUTPUT ###
            if self.verbose_training:
                print('-----------------------------------')
                print('highest reward: {} | reward: {} | timesteps completed: {}'.format(highest_reward, episode_reward, episode_steps))
                print('-----------------------------------\n')

    def load_saved_nets_from_checkpoint(self, load_best: bool):

        # Load the saved networks from the checkpoint file
        # Saved networks include the policy, critic, critic_target, policy optimizer and critic optimizer

        if not load_best:
            checkpoint_path = self.checkpoint_folder + f'/{self.checkpoint_file}.pth'
        else:
            checkpoint_path = self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth'

        checkpoint = torch.load(checkpoint_path)

        # Load the policy network
        self.agent.actor.load_state_dict(checkpoint['agent_state_dict'])

        # Load the critic network
        self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])

        # Load the critic target network
        self.agent.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])

        # Load the policy optimizer
        self.agent.actor_optim.load_state_dict(checkpoint['agent_optimizer_state_dict'])

        # Load the critic optimizer
        self.agent.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
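

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). `build_args()` is a
# hypothetical stand-in for the project's argument parser; every attribute it
# must provide is listed in the __init__ docstring above.
#
#     from SAC.RL_Framework_Mujoco import Muscle_Env
#     from simulate import Simulate
#
#     args = build_args()                      # hypothetical argparse wrapper
#     simulation = Simulate(Muscle_Env, args)
#     simulation.train()                       # or: simulation.test(save_dir)
# ---------------------------------------------------------------------------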