# main.py
1
from SAC.RL_Framework_Mujoco import Muscle_Env    #Muscle_Env is the musculoskeletal environment
2
from simulate import Simulate
3
import config
4
import pdb
5
6
def main():
    """Train or test an agent that controls a musculoskeletal system using DRL and RNNs.

    Command-line arguments are parsed with ``config.config_parser()`` and the
    resulting namespace is passed, together with the ``Muscle_Env``
    environment class, to a ``Simulate`` object. The ``Simulate`` object then
    either trains or tests depending on ``args.mode``:

    * ``"train"`` runs ``trainer.train()``.
    * ``"test"``, ``"SFE"``, ``"sensory_pert"``, ``"neural_pert"`` and
      ``"musculo_properties"`` run ``trainer.test(args.test_data_filename)``.

    Notes:
        * Specify in the config file whether to test, and the preferred file
          for saving testing statistics.
        * To make sure the model is saved properly, specify your preferred
          directories for storing the state_dict.
        * A list of possible commands and their functions is given in
          config.py as well as simulate.py.
        * Specify the kinematics by setting ``kinematics_path`` in the config
          file; monkey data is currently provided by default.

    Raises:
        NotImplementedError: if ``args.mode`` is not one of the supported modes.
    """
    # Modes dispatched to trainer.test() rather than trainer.train().
    test_modes = ("test", "SFE", "sensory_pert", "neural_pert", "musculo_properties")

    ### PARAMETERS ###
    parser = config.config_parser()
    args = parser.parse_args()

    ### TRAINING OBJECT ###
    # Simulate receives the full argument namespace and reads its own settings from it.
    trainer = Simulate(Muscle_Env, args)

    ### TRAIN OR TEST ###
    if args.mode == "train":
        trainer.train()
    elif args.mode in test_modes:
        trainer.test(args.test_data_filename)
    else:
        raise NotImplementedError(f"Unsupported mode: {args.mode!r}")
69
if __name__ == '__main__':
    # Only run when executed as a script, not when this module is imported.
    main()