Diff of /append_musculo_targets.py [000000] .. [9f010e]

Switch to unified view

a b/append_musculo_targets.py
1
import config
2
import numpy as np
3
import pickle 
4
from xml.etree import ElementTree as ET
5
from SAC import kinematics_preprocessing_specs 
6
7
### PARAMETERS ###
8
parser = config.config_parser()
9
args = parser.parse_args()
10
11
#Parse the musculo xml
12
tree = ET.parse(args.musculoskeletal_model_path)
13
root = tree.getroot()
14
15
#Load the kinematics to determine the number of targets
16
with open(args.kinematics_path + '/kinematics.pkl', 'rb') as f:
17
            kin = pickle.load(f)    #[num_conditions][n_targets, num_coords, timepoints]
18
19
kinematics_train = kin['train']
20
num_conds = len(kinematics_train)
21
22
#Randomly sample a condition
23
cond_sampled = np.random.randint(0, num_conds)
24
num_targets = kinematics_train[cond_sampled].shape[0]
25
26
for child in root:
27
28
    if child.tag == 'worldbody':
29
        for current_target in range(num_targets):
30
            body_to_add = f'<body name="target{current_target}" pos="0.1 0.1 0.85"> \n'
31
            body_to_add += '<geom size="0.01 0.01 0.01" type="sphere"/> \n'
32
33
            if kinematics_preprocessing_specs.xyz_target[current_target][0]:
34
                body_to_add += f'<joint axis="1 0 0" name="box:x{current_target}" type="slide" limited="false" range="-6.28319  6.28319"></joint> \n'
35
            
36
            if kinematics_preprocessing_specs.xyz_target[current_target][1]:
37
                body_to_add += f'<joint axis="0 0 1" name="box:y{current_target}" type="slide" limited="false" range="-6.28319  6.28319"></joint> \n'
38
39
            if kinematics_preprocessing_specs.xyz_target[current_target][2]:
40
                body_to_add += f'<joint axis="0 1 0" name="box:z{current_target}" type="slide" limited="false" range="-6.28319  6.28319"></joint> \n'
41
42
            body_to_add += '</body>'
43
44
            child.append(ET.fromstring(body_to_add))
45
46
47
#Now save the updated musculoskeletal model with target bodies added
48
tree.write(args.musculoskeletal_model_path[:-len('musculoskeletal_model.xml')] + 'musculo_targets.xml')