a/opengait/data/dataset.py
b/opengait/data/dataset.py
import os
import pickle
import os.path as osp
import torch.utils.data as tordata
import json
from utils import get_msg_mgr
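
# DataSet indexes a gait dataset laid out on disk as
# dataset_root/<label>/<type>/<view>/*.pkl and serves each sequence as
# (data_list, seq_info), with seq_info = [label, type, view, paths].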
|
|
class DataSet(tordata.Dataset):
    def __init__(self, data_cfg, training):
        """
        seqs_info: a list in which each element describes one gait
            sequence as [label, type, view, paths].
        """
|
|
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']
        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]
|
|
        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))
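        # One cache slot per sequence, filled lazily (all at once when
        # data_cfg['cache'] is true).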
        self.seqs_data = [None] * len(self)
        # Group sequence indices by label (useful for label-based sampling).
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)
        if self.cache:
            self.__load_all_data()
|
|
    def __len__(self):
        return len(self.seqs_info)
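
    # Load every .pkl of one sequence; all entries must be non-empty and have
    # equal length, so the per-sequence parts stay aligned.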
    def __loader__(self, paths):
        paths = sorted(paths)
        data_list = []
        for pth in paths:
            if pth.endswith('.pkl'):
                with open(pth, 'rb') as f:
                    data = pickle.load(f)
            else:
                raise ValueError('- Loader - only supports .pkl files!')
            data_list.append(data)
        for idx, data in enumerate(data_list):
            if len(data) != len(data_list[0]):
                raise ValueError(
                    'Each input data({}) should have the same length.'.format(paths[idx]))
            if len(data) == 0:
                raise ValueError(
                    'Each input data({}) should have at least one element.'.format(paths[idx]))
        return data_list
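
    # Return (data_list, seq_info) for one sequence, reading from disk or,
    # when caching is enabled, loading once and reusing the stored copy.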
    def __getitem__(self, idx):
        if not self.cache:
            data_list = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            data_list = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_list
        else:
            data_list = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_list, seq_info
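
    # Warm the cache by touching every index once.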
    def __load_all_data(self):
        for idx in range(len(self)):
            self.__getitem__(idx)
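
    # Build self.seqs_info from the directory tree and the TRAIN_SET /
    # TEST_SET partition JSON.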
    def __dataset_parser(self, data_config, training):
        dataset_root = data_config['dataset_root']
        try:
            data_in_use = data_config['data_in_use']  # [n], true or false
        except KeyError:
            data_in_use = None
|
|
        with open(data_config['dataset_partition'], "rb") as f:
            partition = json.load(f)
        train_set = partition["TRAIN_SET"]
        test_set = partition["TEST_SET"]
        label_list = os.listdir(dataset_root)
        # Keep only the identities that actually exist under dataset_root.
        train_set = [label for label in train_set if label in label_list]
        test_set = [label for label in test_set if label in label_list]
        miss_pids = [label for label in label_list
                     if label not in (train_set + test_set)]
        msg_mgr = get_msg_mgr()
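
        # Log a compact preview of a PID list: first, second, and last entry.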
        def log_pid_list(pid_list):
            if len(pid_list) >= 3:
                msg_mgr.log_info('[%s, %s, ..., %s]' %
                                 (pid_list[0], pid_list[1], pid_list[-1]))
            else:
                msg_mgr.log_info(pid_list)
|
|
        if len(miss_pids) > 0:
            msg_mgr.log_debug('-------- Miss Pid List --------')
            msg_mgr.log_debug(miss_pids)
        if training:
            msg_mgr.log_info("-------- Train Pid List --------")
            log_pid_list(train_set)
        else:
            msg_mgr.log_info("-------- Test Pid List --------")
            log_pid_list(test_set)
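
        # Walk dataset_root/<label>/<type>/<view>/ and collect
        # [label, type, view, pkl_paths] for every non-empty view folder.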
        def get_seqs_info_list(label_set):
            seqs_info_list = []
            for lab in label_set:
                for typ in sorted(os.listdir(osp.join(dataset_root, lab))):
                    for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))):
                        seq_info = [lab, typ, vie]
                        seq_path = osp.join(dataset_root, *seq_info)
                        seq_dirs = sorted(os.listdir(seq_path))
                        if seq_dirs:
                            seq_dirs = [osp.join(seq_path, d)
                                        for d in seq_dirs]
                            if data_in_use is not None:
                                seq_dirs = [d for d, use_bl in zip(
                                    seq_dirs, data_in_use) if use_bl]
                            seqs_info_list.append([*seq_info, seq_dirs])
                        else:
                            msg_mgr.log_debug(
                                'Found no .pkl file in %s-%s-%s.' % (lab, typ, vie))
            return seqs_info_list
|
|
        self.seqs_info = get_seqs_info_list(
            train_set) if training else get_seqs_info_list(test_set)
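

# Usage sketch: the paths below are hypothetical placeholders, and the sketch
# assumes the message manager behind get_msg_mgr() has been initialized
# before the dataset is built.
#
#     data_cfg = {
#         'dataset_root': '/path/to/dataset-pkl',
#         'dataset_partition': '/path/to/partition.json',  # TRAIN_SET/TEST_SET
#         'cache': False,
#     }
#     dataset = DataSet(data_cfg, training=True)
#     data_list, seq_info = dataset[0]  # seq_info: [label, type, view, paths]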