|
a |
|
b/main.py |
|
|
1 |
from SAC.RL_Framework_Mujoco import Muscle_Env #Muscle_Env is the musculoskeletal environment |
|
|
2 |
from simulate import Simulate |
|
|
3 |
import config |
|
|
4 |
import pdb |
|
|
5 |
|
|
|
6 |
def main():
    """Train or test an agent controlling a musculoskeletal system using DRL and RNNs.

    Command-line arguments are parsed with the parser returned by
    ``config.config_parser()`` and passed, together with the ``Muscle_Env``
    environment class, to a ``Simulate`` object, which runs training or
    testing depending on ``args.mode``.

    Usage notes:
        * Specify whether testing or not in the config file, along with the
          preferred file to save testing statistics.
        * To ensure the model is saved properly, specify your preferred
          directories for storing the state_dict.
        * A list of possible commands with their functions is given in the
          config.py file as well as the simulate.py file.
        * Specify kinematics by choosing kinematics_path in the config file;
          monkey data is currently provided by default.

    Raises:
        NotImplementedError: If ``args.mode`` is neither ``"train"`` nor one
            of the supported test modes.
    """
    # Modes (other than "train") that are dispatched to ``trainer.test``.
    test_modes = ("test", "SFE", "sensory_pert", "neural_pert", "musculo_properties")

    ### PARAMETERS ###
    parser = config.config_parser()
    args = parser.parse_args()
    mode = args.mode

    # Reject an unsupported mode before building the trainer, so a bad
    # command line fails immediately and names the offending value.
    if mode != "train" and mode not in test_modes:
        raise NotImplementedError(
            f"Unsupported mode {mode!r}; expected 'train' or one of {test_modes}"
        )

    ### TRAINING OBJECT ###
    trainer = Simulate(Muscle_Env, args)

    ### TRAIN OR TEST ###
    if mode == "train":
        trainer.train()
    else:
        trainer.test(args.test_data_filename)
|
|
68 |
|
|
|
69 |
# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()