[87e8bf]: / myosuite / logger / roboset_logger.py

Download this file

92 lines (73 with data), 3.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from myosuite.logger.grouped_datasets import Trace, TraceType
import numpy as np
import json
class RoboSet_Trace(Trace):
    """Trace that stores its groups in the RoboSet dataset layout.

    Groups are accumulated in the RoboHive layout and converted with
    :meth:`path2dataset` right before being written by :meth:`save`.
    """

    def __init__(self, name, **kwargs):
        """Build the underlying ``Trace`` and tag it as ``TraceType.ROBOSET``."""
        super().__init__(name=name, **kwargs)
        self.trace_type = TraceType.ROBOSET

    # parse path from robohive format into robopen dataset format
    def path2dataset(self, path: dict, config_path=None) -> dict:
        """
        Convert RoboHive format into roboset format

        Args:
            path: flat mapping of RoboHive keys (e.g. ``'env_infos/obs_dict/time'``,
                ``'actions'``) to their recorded values. ``'env_infos/obs_dict/time'``
                is mandatory; every other entry is copied only when present.
            config_path: optional path of a JSON file whose parsed content is
                stored under ``'config'``.

        Returns:
            A new dict keyed with ``'data/...'``, ``'derived/...'`` and
            ``'config...'`` entries.

        Raises:
            KeyError: if ``'env_infos/obs_dict/time'`` is missing from ``path``.
            AssertionError: if only one of ``pos_ee`` / ``rot_ee`` is present.
        """
        obs_prefix = 'env_infos/obs_dict/'
        path_keys = path.keys()
        dataset = {}

        # Data =====
        dataset['data/time'] = path[obs_prefix + 'time']

        # actions: first 7 columns go to the arm, the remaining ones to the end effector
        if 'actions' in path_keys:
            dataset['data/ctrl_arm'] = path['actions'][:, :7]
            dataset['data/ctrl_ee'] = path['actions'][:, 7:]

        # states
        for key in ['qp_arm', 'qv_arm', 'tau_arm', 'qp_ee', 'qv_ee']:
            roboset_keyin_path = obs_prefix + key
            if roboset_keyin_path in path_keys:
                dataset['data/' + key] = path[roboset_keyin_path]

        # cams: a key is matched by substring on the camera name and the stream tag
        for cam in ['left', 'right', 'top', 'wrist']:
            for key in path_keys:
                if cam in key:
                    if 'rgb:' in key:
                        dataset['data/rgb_' + cam] = path[key]
                    elif 'd:' in key:
                        dataset['data/d_' + cam] = path[key]

        # user: guard on the very key that is read, so the entry is neither
        # skipped when available nor looked up when absent
        user_key = obs_prefix + 'user'
        if user_key in path_keys:
            dataset['data/user'] = path[user_key]

        # Derived =====
        pos_key = obs_prefix + 'pos_ee'
        rot_key = obs_prefix + 'rot_ee'
        if pos_key in path_keys or rot_key in path_keys:
            assert (pos_key in path_keys and rot_key in path_keys), "Both pose_ee and rot_ee are required"
            dataset['derived/pose_ee'] = np.hstack([path[pos_key], path[rot_key]])

        # Config =====
        if config_path:
            # context manager so the config file handle is always released
            with open(config_path, 'rb') as config_file:
                dataset['config'] = json.load(config_file)
            # NOTE(review): this check is assumed to belong to the config section
            # (i.e. only evaluated when config_path is given) -- confirm.
            if 'user_cmt' in path_keys:
                dataset['config/solved'] = np.array(np.float16(path['user_cmt']))

        return dataset

    # Save
    def save(self,
             # save options
             trace_name: str,
             # compression options
             compressions='gzip',
             compression_opts=4,
             **kwargs
             ):
        """Convert every group to the RoboSet layout and delegate to ``Trace.save``.

        Args:
            trace_name: forwarded to the parent ``save``.
            compressions: forwarded to the parent ``save``.
            compression_opts: forwarded to the parent ``save``.
            **kwargs: forwarded to ``self.close`` (when the trace still needs
                closing) and to the parent ``save``.
        """
        # close trace before saving
        if not self.verify_stacked_flattened():
            print("Closing Trace: " + self.name)
            self.close(**kwargs)

        # Roboset format
        # Only values of existing keys are reassigned, so iterating items() is safe.
        for grp_k, grp_v in self.trace.items():
            self.trace[grp_k] = self.path2dataset(grp_v)

        super().save(trace_name=trace_name, compressions=compressions, compression_opts=compression_opts, **kwargs)

    # Load
    def load(self, trace_type, **kwargs):
        """
        Ensure that input type is RoboSet format before loading

        Args:
            trace_type: value resolved through ``TraceType.get_type``; it must
                resolve to ``TraceType.ROBOSET``.
            **kwargs: forwarded to the parent ``load``.

        Raises:
            AssertionError: if ``trace_type`` does not resolve to ``TraceType.ROBOSET``.
        """
        trace_type = TraceType.get_type(trace_type)
        assert trace_type == TraceType.ROBOSET, "RoboSet_Trace requires TraceType.ROBOSET as trace_type"
        super().load(trace_type=trace_type, **kwargs)