--- /dev/null
+++ b/simulate.py
@@ -0,0 +1,503 @@
+import numpy as np
+import torch
+from SAC.sac import SAC_Agent
+from SAC.replay_memory import PolicyReplayMemory
+from SAC.RL_Framework_Mujoco import Muscle_Env
+from SAC import sensory_feedback_specs, kinematics_preprocessing_specs, perturbation_specs
+import pickle
+import os
+from numpy.core.records import fromarrays
+from scipy.io import savemat
+
+
+#Set the working directory (note: os.chdir(os.getcwd()) is a no-op as written)
+os.chdir(os.getcwd())
+
+class Simulate:
+
+    def __init__(self, env: Muscle_Env, args):
+
+        """Train a soft actor-critic (SAC) agent to control a musculoskeletal model so that it follows a kinematic trajectory.
+
+        Parameters (all but ``env`` are read from ``args``)
+        ----------
+        env: Muscle_Env
+            environment (musculoskeletal model) class to train on
+        model: str
+            whether to use a vanilla rnn or a gated rnn (gru)
+        gamma: float
+            discount factor in the sac algorithm
+        tau: float
+            coefficient for the soft updates of the critic and critic target networks
+        lr: float
+            learning rate of the critic and actor networks
+        alpha: float
+            entropy coefficient used in the sac policy update
+        automatic_entropy_tuning: bool
+            whether to tune the entropy coefficient automatically during training
+        seed: int
+            seed for the environment and networks
+        policy_batch_size: int
+            number of episodes in a batch during training
+        hidden_size: int
+            number of hidden neurons in the actor and critic
+        policy_replay_size: int
+            number of episodes to store in the replay memory
+        multi_policy_loss: bool
+            use additional policy losses during updates (reduce the norm of the weights);
+            ONLY USE WITH RNN, NOT IMPLEMENTED WITH GATING
+        batch_iters: int
+            how many experience-replay update rounds (not steps!) to perform at
+            each environment step
+        cuda: bool
+            run on a cuda gpu
+        visualize: bool
+            render the model during simulation
+        model_save_name: str
+            name under which to save the model
+        episodes: int
+            total number of episodes to run
+        save_iter: int
+            number of episodes between checkpoint saves
+        checkpoint_path: str
+            path used to save and load the model
+        muscle_path: str
+            path to the musculoskeletal model
+        muscle_params_path: str
+            path to the musculoskeletal model parameters
+        """
+
+        ### TRAINING VARIABLES ###
+        self.episodes = args.total_episodes
+        self.hidden_size = args.hidden_size
+        self.policy_batch_size = args.policy_batch_size
+        self.visualize = args.visualize
+        self.root_dir = args.root_dir
+        self.checkpoint_file = args.checkpoint_file
+        self.checkpoint_folder = args.checkpoint_folder
+        self.statistics_folder = args.statistics_folder
+        self.batch_iters = args.batch_iters
+        self.save_iter = args.save_iter
+        self.mode_to_sim = args.mode
+        self.condition_selection_strategy = args.condition_selection_strategy
+        self.load_saved_nets_for_training = args.load_saved_nets_for_training
+        self.verbose_training = args.verbose_training
+
+        ### ENSURE SAVING FILES ARE ACCURATE ###
+        assert isinstance(self.root_dir, str)
+        assert isinstance(self.checkpoint_folder, str)
+        assert isinstance(self.checkpoint_file, str)
+
+        ### SEED ###
+        torch.manual_seed(args.seed)
+        np.random.seed(args.seed)
+
+
+        ### LOAD CUSTOM GYM ENVIRONMENT ###
+        # Swap the base musculoskeletal XML for the variant that includes the
+        # kinematic targets (perturbed targets in "musculo_properties" mode)
+        if self.mode_to_sim in ["musculo_properties"]:
+            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets_pert.xml', 1, args)
+
+        else:
+            self.env = env(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets.xml', 1, args)
+
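+        # Observation size = raw env observation + a 3-D visual-velocity signal
+        # per tracked target + 1 task-condition scalar (appended to the state
+        # in train()/test())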
+        self.observation_shape = self.env.observation_space.shape[0] + len(self.env.sfs_visual_velocity)*3 + 1
+
+        ### SAC AGENT ###
+        self.agent = SAC_Agent(self.observation_shape, 
+                               self.env.action_space, 
+                               args.hidden_size, 
+                               args.lr, 
+                               args.gamma, 
+                               args.tau, 
+                               args.alpha, 
+                               args.automatic_entropy_tuning, 
+                               args.model,
+                               args.multi_policy_loss,
+                               args.alpha_usim,
+                               args.beta_usim,
+                               args.gamma_usim,
+                               args.zeta_nusim,
+                               args.cuda)
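+        # (assumed) alpha_usim, beta_usim, gamma_usim and zeta_nusim weight the
+        # additional policy losses enabled by multi_policy_loss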
+
+        ### REPLAY MEMORY ###
+        self.policy_memory = PolicyReplayMemory(args.policy_replay_size, args.seed)
+
+
+    def test(self, save_name):
+
+        """ Run a saved model on every experimental condition and record the
+            resulting muscle activity, network activity and kinematics.
+
+            saves these values to `save_name`:
+            --------
+                test_data.pkl: dict
+                    - EMG (muscle activations), rnn activity, rnn inputs and the
+                      kinematics of the tracked musculo bodies and their targets,
+                      each keyed by condition
+                Data_jpca.mat, n_fixedsteps_jpca.mat, condition_tpoints_jpca.mat
+                    - rnn activity and timing information formatted for jPCA analysis
+        """
+
+        #Update the environment kinematics to include both the training and testing conditions
+        self.env.update_kinematics_for_test()
+
+        ### LOAD SAVED MODEL ###
+        self.load_saved_nets_from_checkpoint(load_best = True)
+
+        #In SFE (sensory feedback elimination) mode, ablate the recurrent
+        #connections by zeroing the recurrent weight matrix
+        if self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
+            self.agent.actor.rnn.weight_hh_l0 = torch.nn.Parameter(self.agent.actor.rnn.weight_hh_l0 * 0)
+
+        ### TESTING DATA ###
+        Test_Data = {
+            "emg": {},
+            "rnn_activity": {},
+            "rnn_input": {},
+            "rnn_input_fp": {},
+            "kinematics_mbodies": {},
+            "kinematics_mtargets": {}
+        }
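+        # Each field is filled below with a dict keyed by condition index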
+
+        ### DATA FOR jPCA ###
+        activity_jpca = []
+        times_jpca = []
+        n_fixedsteps_jpca = []
+        condition_tpoints_jpca = []
+
+        #Save the following after testing: EMG, RNN_activity, RNN_input, kin_musculo_bodies, kin_musculo_targets
+        emg = {}
+        rnn_activity = {}
+        rnn_input = {}
+        kin_mb = {}
+        kin_mt = {}
+        rnn_input_fp = {}
+
+        for i_cond_sim in range(self.env.n_exp_conds):
+
+            ### TRACKING VARIABLES ###
+            episode_reward = 0
+            episode_steps = 0 
+
+            emg_cond = []
+            kin_mb_cond = []
+            kin_mt_cond = []
+            rnn_activity_cond = []
+            rnn_input_cond = []
+            rnn_input_fp_cond = []
+
+            done = False
+
+            ### GET INITIAL STATE + RESET MODEL BY POSE ###
+            cond_to_select = i_cond_sim % self.env.n_exp_conds
+            state = self.env.reset(cond_to_select)
+
+            #In SFE mode the task scalar can also be eliminated (set to 0)
+            if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
+                state = [*state, 0]
+            else:
+                state = [*state, self.env.condition_scalar]
+
+            # Hidden state is [num_layers, batch, hidden_size]; num_layers is
+            # fixed by the policy model
+            h_prev = torch.zeros(size=(1, 1, self.hidden_size))
+
+            ### STEPS PER EPISODE ###
+            for timestep in range(self.env.timestep_limit):
+
+                ### SELECT ACTION ###
+                with torch.no_grad():
+
+                    if self.mode_to_sim in ["neural_pert"]:
+                        #Tile the perturbation sequence over the episode and
+                        #inject it into the rnn at every timestep
+                        state = torch.FloatTensor(state).to(self.agent.device).unsqueeze(0).unsqueeze(0)
+                        h_prev = h_prev.to(self.agent.device)
+                        neural_pert = perturbation_specs.neural_pert[timestep % perturbation_specs.neural_pert.shape[0], :]
+                        neural_pert = torch.FloatTensor(neural_pert).to(self.agent.device).unsqueeze(0).unsqueeze(0)
+                        action, h_current, rnn_act, rnn_in = self.agent.actor.forward_for_neural_pert(state, h_prev, neural_pert)
+
+                    elif self.mode_to_sim in ["SFE"] and "recurrent_connections" in perturbation_specs.sf_elim:
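+                        #Recurrent ablation: also zero the hidden state at every step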
+                        h_prev = h_prev*0
+                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)
+                    else:
+                        action, h_current, rnn_act, rnn_in = self.agent.select_action(state, h_prev, evaluate=True)
+
+                    
+                    emg_cond.append(action) #[n_muscles, ]
+                    rnn_activity_cond.append(rnn_act[0, :])  # [1, n_hidden_units] --> [n_hidden_units,]
+                    rnn_input_cond.append(state) #[n_inputs, ]
+                    rnn_input_fp_cond.append(rnn_in[0, 0, :])  #[1, 1, n_hidden_units] --> [n_hidden_units, ]
+
+
+                ### TRACKING REWARD + EXPERIENCE TUPLE ###
+                next_state, reward, done, _ = self.env.step(action)
+                
+                if self.mode_to_sim in ["SFE"] and "task_scalar" in perturbation_specs.sf_elim:
+                    next_state = [*next_state, 0]
+                else:
+                    next_state = [*next_state, self.env.condition_scalar]
+
+                episode_reward += reward
+
+                #Now append the kinematics of each tracked musculo body and its corresponding target
+                kin_mb_t = []
+                kin_mt_t = []
+                for musculo_body in kinematics_preprocessing_specs.musculo_tracking:
+                    kin_mb_t.append(self.env.sim.data.get_body_xpos(musculo_body[0]).copy())   #[3, ]    
+                    kin_mt_t.append(self.env.sim.data.get_body_xpos(musculo_body[1]).copy()) #[3, ]
+
+                kin_mb_cond.append(kin_mb_t)   # kin_mb_t : [n_targets, 3]
+                kin_mt_cond.append(kin_mt_t)   # kin_mt_t : [n_targets, 3]
+
+                ### VISUALIZE MODEL ###
+                if self.visualize:
+                    self.env.render()
+
+                state = next_state
+                h_prev = h_current
+            
+            #Append the testing data
+            emg[i_cond_sim] = np.array(emg_cond)  # [timepoints, muscles]
+            rnn_activity[i_cond_sim] = np.array(rnn_activity_cond) # [timepoints, n_hidden_units]
+            rnn_input[i_cond_sim] = np.array(rnn_input_cond)  #[timepoints, n_inputs]
+            rnn_input_fp[i_cond_sim] = np.array(rnn_input_fp_cond)  #[timepoints, n_hidden_units]
+            kin_mb[i_cond_sim] = np.array(kin_mb_cond).transpose(1, 0, 2)   # kin_mb_cond: [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
+            kin_mt[i_cond_sim] = np.array(kin_mt_cond).transpose(1, 0, 2)   # kin_mt_cond: [timepoints, n_targets, 3] --> [n_targets, timepoints, 3]
+
+            ##Append the jpca data
+            activity_jpca.append(dict(A = rnn_activity_cond))
+            times_jpca.append(dict(times = np.arange(timestep + 1)))    #one time entry per simulated step; each step is assumed to be 1ms
+            condition_tpoints_jpca.append(self.env.kin_to_sim[self.env.current_cond_to_sim].shape[-1])
+            n_fixedsteps_jpca.append(self.env.n_fixedsteps)
+
+        ### SAVE TESTING STATS ###
+        Test_Data["emg"] = emg
+        Test_Data["rnn_activity"] = rnn_activity
+        Test_Data["rnn_input"] = rnn_input
+        Test_Data["rnn_input_fp"] = rnn_input_fp
+        Test_Data["kinematics_mbodies"] = kin_mb
+        Test_Data["kinematics_mtargets"] = kin_mt
+
+        ### Save the jPCA data as a MATLAB-style struct array with fields 'A' and 'times'
+        Data_jpca = fromarrays([activity_jpca, times_jpca], names=['A', 'times'])
+
+        #save test data
+        with open(save_name + '/test_data.pkl', 'wb') as f:
+            pickle.dump(Test_Data, f)
+
+        #save jpca data
+        savemat(save_name + '/Data_jpca.mat', {'Data' : Data_jpca})
+        savemat(save_name + '/n_fixedsteps_jpca.mat', {'n_fsteps' : n_fixedsteps_jpca})
+        savemat(save_name + '/condition_tpoints_jpca.mat', {'cond_tpoints': condition_tpoints_jpca})
+        
+
+    def train(self):
+
+        """ Train an RNN-based SAC agent to follow a kinematic trajectory
+        """
+
+        #Load the saved networks from the previous training run
+        if self.load_saved_nets_for_training:
+            self.load_saved_nets_from_checkpoint(load_best=False)
+
+        ### TRAINING DATA DICTIONARY ###
+        Statistics = {
+            "rewards": [],
+            "steps": [],
+            "policy_loss": [],
+            "critic_loss": []
+        }
+
+        highest_reward = -float("inf") # used for storing highest reward throughout training
+
+        #Per-condition reward tracking for the "reward" condition-selection strategy
+        cond_train_count = np.ones((self.env.n_exp_conds,))
+        cond_avg_reward = np.zeros((self.env.n_exp_conds,))
+        cond_cum_reward = np.zeros((self.env.n_exp_conds,))
+        cond_cum_count = np.zeros((self.env.n_exp_conds,))
+
+        ### BEGIN TRAINING ###
+        for episode in range(self.episodes):
+
+            ### Gather Episode Data Variables ###
+            episode_reward = 0          # reward for single episode
+            episode_steps = 0           # steps completed for single episode
+            policy_loss_tracker = []    # stores policy loss throughout episode
+            critic1_loss_tracker = []   # stores critic loss throughout episode
+            done = False                # determines if episode is terminated
+
+            ### GET INITIAL STATE + RESET MODEL BY POSE ###
+            if self.condition_selection_strategy != "reward":
+                cond_to_select = episode % self.env.n_exp_conds
+                state = self.env.reset(cond_to_select)
+
+            else:
+                #Pick the first condition that still has episodes scheduled in
+                #the current round
+                cond_indx = np.nonzero(cond_train_count > 0)[0][0]
+                cond_train_count[cond_indx] = cond_train_count[cond_indx] - 1
+                state = self.env.reset(cond_indx)
+
+            #Append the high-level task scalar signal
+            state = [*state, self.env.condition_scalar]
+
+            ep_trajectory = []  # used to store (s_t, a_t, r_t, s_t+1) tuple for replay storage
+
+            h_prev = torch.zeros(size=(1, 1, self.hidden_size)) # num_layers specified in the policy model
+
+            ### LOOP THROUGH EPISODE TIMESTEPS ###
+            while not done:
+
+                ### SELECT ACTION ###
+                with torch.no_grad():
+                    action, h_current, _, _ = self.agent.select_action(state, h_prev, evaluate=False)
+                    
+                    #Query the index of the current neural-activity sample from the simulator
+                    na_idx = self.env.coord_idx
+
+                ### UPDATE MODEL PARAMETERS ###
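+                # Run batch_iters gradient-update rounds at every environment
+                # step once the buffer holds more episodes than one batch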
+                if len(self.policy_memory.buffer) > self.policy_batch_size:
+                    for _ in range(self.batch_iters):
+                        critic_1_loss, critic_2_loss, policy_loss = self.agent.update_parameters(self.policy_memory, self.policy_batch_size)
+                        ### STORE LOSSES ###
+                        policy_loss_tracker.append(policy_loss)
+                        critic1_loss_tracker.append(critic_1_loss)
+
+                ### SIMULATION ###
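+                # Repeat the chosen action for frame_repeat simulator substeps,
+                # accumulating the intermediate rewards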
+                reward = 0
+                for _ in range(self.env.frame_repeat):
+                    next_state, inter_reward, done, _ = self.env.step(action)
+                    next_state = [*next_state, self.env.condition_scalar]
+
+                    reward += inter_reward
+                    episode_steps += 1
+
+                    ### EARLY TERMINATION OF EPISODE ###
+                    if done:
+                        break
+
+                episode_reward += reward
+
+                ### VISUALIZE MODEL ###
+                if self.visualize:
+                    self.env.render()
+
+                # Time-limit truncation is not a true terminal state, so keep mask = 1 in that case
+                mask = 1 if episode_steps == self.env._max_episode_steps else float(not done)
+
+                ### STORE CURRENT TIMESTEP TUPLE ###
+                #If no recorded neural data exists for this model, store zeros
+                if not self.env.nusim_data_exists:
+                    self.env.neural_activity[na_idx] = 0
+
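+                # Experience tuple: (state, action, reward, next_state, mask,
+                # hidden state, target neural activity, neural-activity index)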
+                ep_trajectory.append((state, 
+                                        action, 
+                                        reward, 
+                                        next_state, 
+                                        mask,
+                                        h_current.squeeze(0).cpu().numpy(),
+                                        self.env.neural_activity[na_idx],
+                                        np.array([na_idx])))
+
+                ### MOVE TO NEXT STATE ###
+                state = next_state
+                h_prev = h_current
+
+            
+            ### PUSH TO REPLAY ###
+            self.policy_memory.push(ep_trajectory)
+
+            if self.condition_selection_strategy == "reward":
+
+                cond_cum_reward[cond_indx] = cond_cum_reward[cond_indx] + episode_reward
+                cond_cum_count[cond_indx] = cond_cum_count[cond_indx] + 1
+
+                #Once every condition's scheduled episodes are exhausted,
+                #recompute the schedule: conditions with lower average reward
+                #receive proportionally more episodes in the next round
+                if np.all((cond_train_count == 0)):
+                    cond_avg_reward = cond_cum_reward / cond_cum_count
+                    cond_train_count = np.ceil((np.max(cond_avg_reward)*np.ones((self.env.n_exp_conds,)))/cond_avg_reward)
+
+            ### TRACKING ###
+            Statistics["rewards"].append(episode_reward)
+            Statistics["steps"].append(episode_steps)
+            Statistics["policy_loss"].append(np.mean(np.array(policy_loss_tracker)))
+            Statistics["critic_loss"].append(np.mean(np.array(critic1_loss_tracker)))
+
+            ### SAVE TRAINING STATISTICS TO FILE ###
+            if len(self.statistics_folder) != 0:
+                np.save(self.statistics_folder + '/stats_rewards.npy', Statistics['rewards'])
+                np.save(self.statistics_folder + '/stats_steps.npy', Statistics['steps'])
+                np.save(self.statistics_folder + '/stats_policy_loss.npy', Statistics['policy_loss'])
+                np.save(self.statistics_folder + '/stats_critic_loss.npy', Statistics['critic_loss'])
+
+
+            ### SAVING STATE DICT OF TRAINING ###
+            if len(self.checkpoint_folder) != 0 and len(self.checkpoint_file) != 0:
+                if episode % self.save_iter == 0 and len(self.policy_memory.buffer) > self.policy_batch_size:
+                    
+                    #Save the state dicts
+                    torch.save({
+                         'iteration': episode,
+                         'agent_state_dict': self.agent.actor.state_dict(),
+                         'critic_state_dict': self.agent.critic.state_dict(),
+                         'critic_target_state_dict': self.agent.critic_target.state_dict(),
+                         'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
+                         'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
+                     }, self.checkpoint_folder + f'/{self.checkpoint_file}.pth')
+
+                    #Save the full rnn module (not just the state dict) for fixed-point finder analysis
+                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + '/actor_rnn_fpf.pth')
+
+                
+                if episode_reward > highest_reward:
+                    torch.save({
+                            'iteration': episode,
+                            'agent_state_dict': self.agent.actor.state_dict(),
+                            'critic_state_dict': self.agent.critic.state_dict(),
+                            'critic_target_state_dict': self.agent.critic_target.state_dict(),
+                            'agent_optimizer_state_dict': self.agent.actor_optim.state_dict(),
+                            'critic_optimizer_state_dict': self.agent.critic_optim.state_dict(),
+                        }, self.checkpoint_folder + f'/{self.checkpoint_file}_best.pth')
+
+                    #Save the full rnn module (not just the state dict) for fixed-point finder analysis
+                    torch.save(self.agent.actor.rnn, self.checkpoint_folder + '/actor_rnn_best_fpf.pth')
+
+            if episode_reward > highest_reward:
+                highest_reward = episode_reward
+
+            ### PRINT TRAINING OUTPUT ###
+            if self.verbose_training:
+                print('-----------------------------------')
+                print('highest reward: {} | reward: {} | timesteps completed: {}'.format(highest_reward, episode_reward, episode_steps))
+                print('-----------------------------------\n')
+
+    def load_saved_nets_from_checkpoint(self, load_best: bool):
+
+        """ Load the saved networks (policy, critic, critic target, policy
+            optimizer and critic optimizer) from the checkpoint file; the
+            `_best` checkpoint if load_best is True, the latest one otherwise.
+        """
+
+        suffix = '_best' if load_best else ''
+        checkpoint = torch.load(self.checkpoint_folder + f'/{self.checkpoint_file}{suffix}.pth')
+
+        #Load the policy, critic and critic target networks
+        self.agent.actor.load_state_dict(checkpoint['agent_state_dict'])
+        self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])
+        self.agent.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
+
+        #Load the policy and critic optimizers
+        self.agent.actor_optim.load_state_dict(checkpoint['agent_optimizer_state_dict'])
+        self.agent.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
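+
+
+# A minimal usage sketch (hypothetical; the actual entry point lives in the
+# project's run script). It assumes an `args` namespace exposing the fields
+# referenced above (seed, hidden_size, mode, checkpoint/statistics folders and
+# the SAC hyperparameters):
+#
+#     from SAC.RL_Framework_Mujoco import Muscle_Env
+#
+#     sim = Simulate(Muscle_Env, args)
+#     sim.train()            # or, with a saved checkpoint:
+#     sim.test(results_dir)  # results_dir receives test_data.pkl and the jPCA .mat files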