--- a
+++ b/simulate.py
@@ -0,0 +1,503 @@
+import numpy as np
+import torch
+from SAC.sac import SAC_Agent
+from SAC.replay_memory import PolicyReplayMemory
+from SAC.RL_Framework_Mujoco import Muscle_Env
+from SAC import sensory_feedback_specs, kinematics_preprocessing_specs, perturbation_specs
+import pickle
+import os
+from numpy.core.records import fromarrays
+from scipy.io import savemat
+
+
+#Set the current working directory
+os.chdir(os.getcwd())
+
+class Simulate():
+
+    def __init__(self, env: Muscle_Env, args):
+
+        """Train a soft actor-critic agent to control a musculoskeletal model to follow a kinematic trajectory.
+
+        Parameters
+        ----------
+        env: Muscle_Env
+            the musculoskeletal environment (model) class to instantiate and train on
+        model: str
+            specify whether to use an RNN or a gated RNN (GRU)
+        gamma: float
+            discount factor used in the SAC algorithm
+        tau: float
+            tau parameter for soft updates of the critic and critic target networks
+        lr: float
+            learning rate of the critic and actor networks
+        alpha: float
+            entropy temperature used in the SAC policy update
+        automatic_entropy_tuning: bool
+            whether to use automatic entropy tuning during training
+        seed: int
+            seed for the environment and networks
+        policy_batch_size: int
+            the number of episodes used in a batch during training
+        hidden_size: int
+            number of hidden neurons in the actor and critic
+        policy_replay_size: int
+            number of episodes to store in the replay buffer
+        multi_policy_loss: bool
+            use additional policy losses during updates (reduce norm of weights);
+            ONLY USE WITH RNN, NOT IMPLEMENTED WITH GATING
+        batch_iters: int
+            how many experience replay rounds (not steps!) to perform between updates
+        cuda: bool
+            run the networks on a CUDA GPU
+        visualize: bool
+            render the model during simulation
+        model_save_name: str
+            specify the name of the model to save
+        episodes: int
+            total number of episodes to run
+        save_iter: int
+            number of iterations between model saves
+        checkpoint_path: str
+            specify path to save and load the model
+        muscle_path: str
+            path for the musculoskeletal model
+        muscle_params_path: str
+            path for the musculoskeletal model parameters
+        """
+
+        ### TRAINING VARIABLES ###
+        self.episodes = args.total_episodes
+        self.hidden_size = args.hidden_size
+        self.policy_batch_size = args.policy_batch_size
+        self.visualize = args.visualize
+        self.root_dir = args.root_dir
+        self.checkpoint_file = args.checkpoint_file
+        self.checkpoint_folder = args.checkpoint_folder
+        self.statistics_folder = args.statistics_folder
+        self.batch_iters = args.batch_iters
+        self.save_iter = args.save_iter
+        self.mode_to_sim = args.mode
+        self.condition_selection_strategy = args.condition_selection_strategy
+        self.load_saved_nets_for_training = args.load_saved_nets_for_training
+        self.verbose_training = args.verbose_training
+
+        ### ENSURE SAVE PATHS ARE STRINGS ###
+        assert isinstance(self.root_dir, str)
+        assert isinstance(self.checkpoint_folder, str)
+        assert isinstance(self.checkpoint_file, str)
+
+        ### SEED ###
+        torch.manual_seed(args.seed)
+        np.random.seed(args.seed)
+
+
+        ### LOAD CUSTOM GYM ENVIRONMENT ###
+        if self.mode_to_sim in ["musculo_properties"]:
+            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets_pert.xml', 1, args)
+
+        else:
+            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets.xml', 1, args)
+
+        self.observation_shape = self.env.observation_space.shape[0] + len(self.env.sfs_visual_velocity)*3 + 1
+
+        ### SAC AGENT ###
+        self.agent = SAC_Agent(self.observation_shape,
+                               self.env.action_space,
+                               args.hidden_size,
+                               args.lr,
+                               args.gamma,
+                               args.tau,
+                               args.alpha,
+                               args.automatic_entropy_tuning,
+                               args.model,
+                               args.multi_policy_loss,
+                               args.alpha_usim,
+                               args.beta_usim,
+                               args.gamma_usim,
+                               args.zeta_nusim,
+                               args.cuda)
+
+        ### REPLAY MEMORY ###
+        self.policy_memory = PolicyReplayMemory(args.policy_replay_size, args.seed)
+
+
+    def test(self, save_name):
+
+        """ Use a saved model to generate a kinematic trajectory
+
+        saves these values:
+        --------
+        episode_reward: float
+            - total reward obtained by the trained model
+        kinematics: list
+            - kinematics of the model during testing
+        hidden_activity: list
+            - list containing the RNN activity during testing
+        """
+
+        #Update the environment kinematics to include both the training and testing conditions
+        self.env.update_kinematics_for_test()
+
+        ### LOAD SAVED MODEL ###
+        self.load_saved_nets_from_checkpoint(load_best=True)
+
+        #Set the recurrent connections to zero if the mode is SFE
+        if self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
+            self.agent.actor.rnn.weight_hh_l0 = torch.nn.Parameter(self.agent.actor.rnn.weight_hh_l0 * 0)
+
+        ### TESTING DATA ###
+        Test_Data = {
+            "emg": {},
+            "rnn_activity": {},
+            "rnn_input": {},
+            "rnn_input_fp": {},
+            "kinematics_mbodies": {},
+            "kinematics_mtargets": {}
+        }
+
+        ### Data for jPCA ###
+        activity_jpca = []
+        times_jpca = []
+        n_fixedsteps_jpca = []
+        condition_tpoints_jpca = []
+
+        #Save the following after testing: EMG, RNN_activity, RNN_input, kin_musculo_bodies, kin_musculo_targets
+        emg = {}
+        rnn_activity = {}
+        rnn_input = {}
+        kin_mb = {}
+        kin_mt = {}
+        rnn_input_fp = {}
+
+        for i_cond_sim in range(self.env.n_exp_conds):
+
+            ### TRACKING VARIABLES ###
+            episode_reward = 0
+            episode_steps = 0
+
+            emg_cond = []
+            kin_mb_cond = []
+            kin_mt_cond = []
+            rnn_activity_cond = []
+            rnn_input_cond = []
+            rnn_input_fp_cond = []
+
+            done = False
+
+            ### GET INITIAL STATE + RESET MODEL BY POSE ###
+            cond_to_select = i_cond_sim % self.env.n_exp_conds
+            state = self.env.reset(cond_to_select)
+
+            if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
+                state = [*state, 0]
+            else:
+                state = [*state, self.env.condition_scalar]
+
+            # num_layers specified in the policy model
+            h_prev = torch.zeros(size=(1, 1, self.hidden_size))
+
+            ### STEPS PER EPISODE ###
+            for timestep in range(self.env.timestep_limit):
+
+                ### SELECT ACTION ###
+                with torch.no_grad():
+
+                    if self.mode_to_sim in ["neural_pert"]:
+                        state = torch.FloatTensor(state).to(self.agent.device).unsqueeze(0).unsqueeze(0)
+                        h_prev = h_prev.to(self.agent.device)
+                        neural_pert = perturbation_specs.neural_pert[timestep % perturbation_specs.neural_pert.shape[0], :]
+                        neural_pert = torch.FloatTensor(neural_pert).to(self.agent.device).unsqueeze(0).unsqueeze(0)
+                        action, h_current, rnn_act, rnn_in = self.agent.actor.forward_for_neural_pert(state, h_prev, neural_pert)
+
+                    elif self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
+                        h_prev = h_prev * 0
+                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)
+                    else:
+                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)
+
+
+                emg_cond.append(action)                     #[n_muscles, ]
+                rnn_activity_cond.append(rnn_act[0, :])     #[1, n_hidden_units] --> [n_hidden_units, ]
+                rnn_input_cond.append(state)                #[n_inputs, ]
+                rnn_input_fp_cond.append(rnn_in[0, 0, :])   #[1, 1, n_hidden_units] --> [n_hidden_units, ]
+
+
+                ### TRACKING REWARD + EXPERIENCE TUPLE ###
+                next_state, reward, done, _ = self.env.step(action)
+
+                if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
+                    next_state = [*next_state, 0]
+                else:
+                    next_state = [*next_state, self.env.condition_scalar]
+
+                episode_reward += reward
+
+                #Now append the kinematics of the musculo body and the corresponding target
+                kin_mb_t = []
+                kin_mt_t = []
+                for musculo_body in kinematics_preprocessing_specs.musculo_tracking:
+                    kin_mb_t.append(self.env.sim.data.get_body_xpos(musculo_body[0]).copy())   #[3, ]
+                    kin_mt_t.append(self.env.sim.data.get_body_xpos(musculo_body[1]).copy())   #[3, ]
+
+                kin_mb_cond.append(kin_mb_t)   # kin_mb_t : [n_targets, 3]
+                kin_mt_cond.append(kin_mt_t)   # kin_mt_t : [n_targets, 3]
+
+                ### VISUALIZE MODEL ###
+                if self.visualize:
+                    self.env.render()
+
+                state = next_state
+                h_prev = h_current
+
+            #Append the testing data
+            emg[i_cond_sim] = np.array(emg_cond)                      # [timepoints, muscles]
+            rnn_activity[i_cond_sim] = np.array(rnn_activity_cond)    # [timepoints, n_hidden_units]
+            rnn_input[i_cond_sim] = np.array(rnn_input_cond)          # [timepoints, n_inputs]
+            rnn_input_fp[i_cond_sim] = np.array(rnn_input_fp_cond)    # [timepoints, n_hidden_units]
+            kin_mb[i_cond_sim] = np.array(kin_mb_cond).transpose(1, 0, 2)   # kin_mb_cond: [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
+            kin_mt[i_cond_sim] = np.array(kin_mt_cond).transpose(1, 0, 2)   # kin_mt_cond: [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
+
+            ##Append the jPCA data
+            activity_jpca.append(dict(A = rnn_activity_cond))
+            times_jpca.append(dict(times = np.arange(timestep)))   # each timestep is assumed to be 1 ms
+            condition_tpoints_jpca.append(self.env.kin_to_sim[self.env.current_cond_to_sim].shape[-1])
+            n_fixedsteps_jpca.append(self.env.n_fixedsteps)
+
+        ### SAVE TESTING STATS ###
+        Test_Data["emg"] = emg
+        Test_Data["rnn_activity"] = rnn_activity
+        Test_Data["rnn_input"] = rnn_input
+        Test_Data["rnn_input_fp"] = rnn_input_fp
+        Test_Data["kinematics_mbodies"] = kin_mb
+        Test_Data["kinematics_mtargets"] = kin_mt
+
+        ### Save the jPCA data
+        Data_jpca = fromarrays([activity_jpca, times_jpca], names=['A', 'times'])
+
+        #Save test data
+        with open(save_name + '/test_data.pkl', 'wb') as f:
+            pickle.dump(Test_Data, f)
+
+        #Save jPCA data
+        savemat(save_name + '/Data_jpca.mat', {'Data' : Data_jpca})
+        savemat(save_name + '/n_fixedsteps_jpca.mat', {'n_fsteps' : n_fixedsteps_jpca})
+        savemat(save_name + '/condition_tpoints_jpca.mat', {'cond_tpoints': condition_tpoints_jpca})
+
+
+    def train(self):
+
+        """ Train an RNN-based SAC agent to follow a kinematic trajectory
+        """
+
+        #Load the saved networks from the last training
+        if self.load_saved_nets_for_training:
+            self.load_saved_nets_from_checkpoint(load_best=False)
+
+        ### TRAINING DATA DICTIONARY ###
+        Statistics = {
+            "rewards": [],
+            "steps": [],
+            "policy_loss": [],
+            "critic_loss": []
+        }
+
+        highest_reward = -float("inf")   # stores the highest episode reward seen throughout training
+
+        #Initialization for averaging reward across conditions
+        cond_train_count = np.ones((self.env.n_exp_conds,))
+        cond_avg_reward = np.zeros((self.env.n_exp_conds,))
+        cond_cum_reward = np.zeros((self.env.n_exp_conds,))
+        cond_cum_count = np.zeros((self.env.n_exp_conds,))
+
+        ### BEGIN TRAINING ###
+        for episode in range(self.episodes):
+
+            ### Gather Episode Data Variables ###
+            episode_reward = 0            # reward for a single episode
+            episode_steps = 0             # steps completed in a single episode
+            policy_loss_tracker = []      # stores policy loss throughout the episode
+            critic1_loss_tracker = []     # stores critic loss throughout the episode
+            done = False                  # determines if the episode is terminated
+
+            ### GET INITIAL STATE + RESET MODEL BY POSE ###
+            if self.condition_selection_strategy != "reward":
+                cond_to_select = episode % self.env.n_exp_conds
+                state = self.env.reset(cond_to_select)
+
+            else:
+
+                cond_indx = np.nonzero(cond_train_count > 0)[0][0]
+                cond_train_count[cond_indx] = cond_train_count[cond_indx] - 1
+                state = self.env.reset(cond_indx)
+
+            #Append the high-level task scalar signal
+            state = [*state, self.env.condition_scalar]
+
+            ep_trajectory = []   # stores the (s_t, a_t, r_t, s_t+1) tuples for replay storage
+
+            h_prev = torch.zeros(size=(1, 1, self.hidden_size))   # num_layers specified in the policy model
+
+            ### LOOP THROUGH EPISODE TIMESTEPS ###
+            while not done:
+
+                ### SELECT ACTION ###
+                with torch.no_grad():
+                    action, h_current, _, _ = self.agent.select_action(state, h_prev, evaluate=False)
+
+                #Now query the neural activity idx from the simulator
+                na_idx = self.env.coord_idx
+
+                ### UPDATE MODEL PARAMETERS ###
+                if len(self.policy_memory.buffer) > self.policy_batch_size:
+                    for _ in range(self.batch_iters):
+                        critic_1_loss, critic_2_loss, policy_loss = self.agent.update_parameters(self.policy_memory, self.policy_batch_size)
+                        ### STORE LOSSES ###
+                        policy_loss_tracker.append(policy_loss)
+                        critic1_loss_tracker.append(critic_1_loss)
+
+                ### SIMULATION ###
+                reward = 0
+                for _ in range(self.env.frame_repeat):
+                    next_state, inter_reward, done, _ = self.env.step(action)
+                    next_state = [*next_state, self.env.condition_scalar]
+
+                    reward += inter_reward
+                    episode_steps += 1
+
+                    ### EARLY TERMINATION OF EPISODE ###
+                    if done:
+                        break
+
+                episode_reward += reward
+
+                ### VISUALIZE MODEL ###
+                if self.visualize:
+                    self.env.render()
+
+                mask = 1 if episode_steps == self.env._max_episode_steps else float(not done)   # ensure mask is not 0 if the episode ends at the step limit
+
+                ### STORE CURRENT TIMESTEP TUPLE ###
+                if not self.env.nusim_data_exists:
+                    self.env.neural_activity[na_idx] = 0
+
+                ep_trajectory.append((state,
+                                      action,
+                                      reward,
+                                      next_state,
+                                      mask,
+                                      h_current.squeeze(0).cpu().numpy(),
+                                      self.env.neural_activity[na_idx],
+                                      np.array([na_idx])))
+
+                ### MOVE TO NEXT STATE ###
+                state = next_state
+                h_prev = h_current
+
+
+            ### PUSH TO REPLAY ###
+            self.policy_memory.push(ep_trajectory)
+
+            if self.condition_selection_strategy == "reward":
+
+                cond_cum_reward[cond_indx] = cond_cum_reward[cond_indx] + episode_reward
+                cond_cum_count[cond_indx] = cond_cum_count[cond_indx] + 1
+
+                #Check if the cond_train_count array is all zeros
+                if np.all((cond_train_count == 0)):
+                    cond_avg_reward = cond_cum_reward / cond_cum_count
+                    cond_train_count = np.ceil((np.max(cond_avg_reward)*np.ones((self.env.n_exp_conds,)))/cond_avg_reward)
+
+            ### TRACKING ###
+            Statistics["rewards"].append(episode_reward)
+            Statistics["steps"].append(episode_steps)
+            Statistics["policy_loss"].append(np.mean(np.array(policy_loss_tracker)))
+            Statistics["critic_loss"].append(np.mean(np.array(critic1_loss_tracker)))
+
+            ### SAVE TRAINING STATISTICS TO FILE ###
+            if len(self.statistics_folder) != 0:
+                np.save(self.statistics_folder + f'/stats_rewards.npy', Statistics['rewards'])
+                np.save(self.statistics_folder + f'/stats_steps.npy', Statistics['steps'])
+                np.save(self.statistics_folder + f'/stats_policy_loss.npy', Statistics['policy_loss'])
+                np.save(self.statistics_folder + f'/stats_critic_loss.npy', Statistics['critic_loss'])
+
+
+            ### SAVING STATE DICT OF TRAINING ###
+            if len(self.checkpoint_folder) != 0 and len(self.checkpoint_file) != 0:
+                if episode % self.save_iter == 0 and len(self.policy_memory.buffer) > self.policy_batch_size:
+
+                    #Save the state dicts
+                    torch.save({
+                        'iteration': episode,
+                        'agent_state_dict': self.agent.actor.state_dict(),
+                        'critic_state_dict': self.agent.critic.state_dict(),
+                        'critic_target_state_dict': self.agent.critic_target.state_dict(),
+                        'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
+                        'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
+                    }, self.checkpoint_folder + f'/{self.checkpoint_file}.pth')
+
+                    #Save the pickled model for fixed-point finder analysis
+                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + f'/actor_rnn_fpf.pth')
+
+
+                if episode_reward > highest_reward:
+                    torch.save({
+                        'iteration': episode,
+                        'agent_state_dict': self.agent.actor.state_dict(),
+                        'critic_state_dict': self.agent.critic.state_dict(),
+                        'critic_target_state_dict': self.agent.critic_target.state_dict(),
+                        'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
+                        'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
+                    }, self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')
+
+                    #Save the pickled model for fixed-point finder analysis
+                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + f'/actor_rnn_best_fpf.pth')
+
+            if episode_reward > highest_reward:
+                highest_reward = episode_reward
+
+            ### PRINT TRAINING OUTPUT ###
+            if self.verbose_training:
+                print('-----------------------------------')
+                print('highest reward: {} | reward: {} | timesteps completed: {}'.format(highest_reward, episode_reward, episode_steps))
+                print('-----------------------------------\n')
+
+    def load_saved_nets_from_checkpoint(self, load_best: bool):
+
+        #Load the saved networks from the checkpoint file
+        #Saved networks include policy, critic, critic_target, policy_optimizer and critic_optimizer
+
+        if not load_best:
+
+            #Load the policy network
+            self.agent.actor.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}.pth')['agent_state_dict'])
+
+            #Load the critic network
+            self.agent.critic.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}.pth')['critic_state_dict'])
+
+            #Load the critic target network
+            self.agent.critic_target.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}.pth')['critic_target_state_dict'])
+
+            #Load the policy optimizer
+            self.agent.actor_optim.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}.pth')['agent_optimizer_state_dict'])
+
+            #Load the critic optimizer
+            self.agent.critic_optim.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}.pth')['critic_optimizer_state_dict'])
+
+        else:
+
+            #Load the policy network
+            self.agent.actor.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')['agent_state_dict'])
+
+            #Load the critic network
+            self.agent.critic.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')['critic_state_dict'])
+
+            #Load the critic target network
+            self.agent.critic_target.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')['critic_target_state_dict'])
+
+            #Load the policy optimizer
+            self.agent.actor_optim.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')['agent_optimizer_state_dict'])
+
+            #Load the critic optimizer
+            self.agent.critic_optim.load_state_dict(torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')['critic_optimizer_state_dict'])
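The driver that constructs Simulate is not part of this patch. The sketch below is a minimal, hypothetical usage example only: it assumes the project's command-line arguments are collected into an argparse.Namespace carrying every attribute read in Simulate.__init__, and the placeholder values and file paths shown here are illustrative, not the repository's real defaults.

import argparse

from SAC.RL_Framework_Mujoco import Muscle_Env   # environment class imported by simulate.py
from simulate import Simulate

# Hypothetical argument values -- the real project defines these in its own parser;
# the attribute names match the fields read in Simulate.__init__.
args = argparse.Namespace(
    total_episodes=500000, hidden_size=256, policy_batch_size=8, policy_replay_size=4000,
    batch_iters=1, save_iter=100, seed=0, lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2,
    automatic_entropy_tuning=True, model='rnn', multi_policy_loss=False,
    alpha_usim=0.0, beta_usim=0.0, gamma_usim=0.0, zeta_nusim=0.0, cuda=False,
    visualize=False, verbose_training=True, mode='train',
    condition_selection_strategy='reward', load_saved_nets_for_training=False,
    root_dir='.', checkpoint_folder='checkpoints', checkpoint_file='sac_model',
    statistics_folder='statistics',
    # Must end with 'musculoskeletal_model.xml'; __init__ strips this suffix to locate the target XML.
    musculoskeletal_model_path='musculoskeletal/musculoskeletal_model.xml',
)

simulation = Simulate(Muscle_Env, args)          # builds the environment, SAC agent and replay buffer
if args.mode == 'train':
    simulation.train()                           # writes stats to statistics_folder and checkpoints to checkpoint_folder
else:
    simulation.test(save_name='test_results')    # existing directory; receives test_data.pkl and the jPCA .mat files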