[9f010e]: / main.py

Download this file

71 lines (63 with data), 2.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from SAC.RL_Framework_Mujoco import Muscle_Env #Muscle_Env is the musculoskeletal environment
from simulate import Simulate
import config
import pdb
def main():
    """Train or test an agent that controls a musculoskeletal system using DRL and RNNs.

    Arguments are parsed with the parser returned by ``config.config_parser()``
    and passed, together with the ``Muscle_Env`` environment class, to a
    ``Simulate`` object. Depending on ``args.mode`` that object is then asked
    to train or to test.

    Usage notes:

    * Specify whether testing or not in the config file, along with the
      preferred file to save testing statistics.
    * To ensure the model is saved properly, specify your preferred
      directories for storing the state_dict.
    * A list of possible commands with their functions is given in the
      config.py file as well as the simulate.py file.
    * Specify kinematics by choosing kinematics_path in the config file;
      monkey data is currently provided by default.

    Raises:
        NotImplementedError: If ``args.mode`` is neither ``"train"`` nor one
            of the supported testing modes.
    """
    # Every mode listed here is dispatched to ``trainer.test``.
    test_modes = ("test", "SFE", "sensory_pert", "neural_pert", "musculo_properties")

    ### PARAMETERS ###
    parser = config.config_parser()
    args = parser.parse_args()

    ### TRAINING OBJECT ###
    trainer = Simulate(Muscle_Env, args)

    ### TRAIN OR TEST ###
    if args.mode == "train":
        trainer.train()
    elif args.mode in test_modes:
        trainer.test(args.test_data_filename)
    else:
        # Same exception type as before (callers may catch it), but now the
        # message names the rejected mode and the accepted alternatives.
        raise NotImplementedError(
            "Unsupported mode {!r}; expected 'train' or one of {}".format(
                args.mode, ", ".join(repr(mode) for mode in test_modes)
            )
        )
# Only launch training/testing when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()