[fd9ef4]: / opengait / data / dataset.py

Download this file

126 lines (112 with data), 5.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import pickle
import os.path as osp
import torch.utils.data as tordata
import json
from utils import get_msg_mgr
class DataSet(tordata.Dataset):
    """Dataset of gait sequences stored as pickled files on disk.

    The tree under ``dataset_root`` is walked as
    ``<label>/<type>/<view>/<files>``; every ``<view>`` directory holding at
    least one entry becomes one sequence of the dataset.
    """

    def __init__(self, data_cfg, training):
        """Index the dataset and, when caching is enabled, preload it.

        Args:
            data_cfg: Mapping read for the keys ``dataset_root``,
                ``dataset_partition``, ``cache`` and, optionally,
                ``data_in_use``.
            training: When true the ``TRAIN_SET`` labels of the partition
                file are indexed, otherwise the ``TEST_SET`` labels.

        After construction ``seqs_info`` is a list in which each element
        describes one gait sequence as ``[label, type, view, paths]``.
        """
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']

        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]

        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))

        # One slot per sequence; filled lazily by __getitem__ when caching.
        self.seqs_data = [None] * len(self)

        # label -> indices (into seqs_info) of all sequences with that label.
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)

        if self.cache:
            self.__load_all_data()

    def __len__(self):
        """Return the number of indexed sequences."""
        return len(self.seqs_info)

    def __loader__(self, paths):
        """Unpickle every file in ``paths`` and return the loaded objects.

        Args:
            paths: Iterable of file paths; they are processed in sorted
                order and the result follows that order.

        Returns:
            A list with one unpickled object per path.

        Raises:
            ValueError: If a path does not end with ``.pkl``, if a loaded
                object's length differs from the first one's, or if a
                loaded object is empty.
        """
        paths = sorted(paths)
        data_list = []
        for pth in paths:
            if not pth.endswith('.pkl'):
                raise ValueError('- Loader - just support .pkl !!!')
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point the dataset at files from a trusted source.
            with open(pth, 'rb') as f:
                data_list.append(pickle.load(f))

        # Every loaded object must be non-empty and as long as the first.
        expected_len = len(data_list[0]) if data_list else 0
        for idx, data in enumerate(data_list):
            if len(data) != expected_len:
                raise ValueError(
                    'Each input data({}) should have the same length.'.format(paths[idx]))
            if len(data) == 0:
                raise ValueError(
                    'Each input data({}) should have at least one element.'.format(paths[idx]))
        return data_list

    def __getitem__(self, idx):
        """Return ``(data_list, seq_info)`` for the sequence at ``idx``.

        Without caching the files are read on every call. With caching the
        first access reads them and stores the result in ``seqs_data``;
        later accesses return the stored list.
        """
        if not self.cache:
            data_list = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            data_list = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_list
        else:
            data_list = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_list, seq_info

    def __load_all_data(self):
        """Touch every index so that each sequence is loaded once."""
        for idx in range(len(self)):
            self.__getitem__(idx)

    def __dataset_parser(self, data_config, training):
        """Build ``self.seqs_info`` from the partition file and directory tree.

        Args:
            data_config: Mapping providing ``dataset_root`` and
                ``dataset_partition`` (path of a JSON file with the keys
                ``TRAIN_SET`` and ``TEST_SET``) and, optionally,
                ``data_in_use`` (a sequence of truthy/falsy flags applied
                positionally to the sorted entries of each view directory).
            training: Selects the train labels when true, else the test
                labels.
        """
        dataset_root = data_config['dataset_root']
        try:
            data_in_use = data_config['data_in_use']  # [n], true or false
        except KeyError:
            # The key is optional; without it every entry is used.
            data_in_use = None

        with open(data_config['dataset_partition'], "rb") as f:
            partition = json.load(f)
        train_set = partition["TRAIN_SET"]
        test_set = partition["TEST_SET"]

        # Keep only the labels that actually exist under dataset_root.
        # Sets make the membership tests O(1) instead of scanning lists.
        label_list = os.listdir(dataset_root)
        available = set(label_list)
        train_set = [label for label in train_set if label in available]
        test_set = [label for label in test_set if label in available]
        partitioned = set(train_set) | set(test_set)
        miss_pids = [label for label in label_list
                     if label not in partitioned]
        msg_mgr = get_msg_mgr()

        def log_pid_list(pid_list):
            # Abbreviate long lists to first two and last element.
            if len(pid_list) >= 3:
                msg_mgr.log_info('[%s, %s, ..., %s]' %
                                 (pid_list[0], pid_list[1], pid_list[-1]))
            else:
                msg_mgr.log_info(pid_list)

        if len(miss_pids) > 0:
            msg_mgr.log_debug('-------- Miss Pid List --------')
            msg_mgr.log_debug(miss_pids)
        if training:
            msg_mgr.log_info("-------- Train Pid List --------")
            log_pid_list(train_set)
        else:
            msg_mgr.log_info("-------- Test Pid List --------")
            log_pid_list(test_set)

        def get_seqs_info_list(label_set):
            # Walk <root>/<label>/<type>/<view> and collect one entry per
            # non-empty view directory.
            seqs_info_list = []
            for lab in label_set:
                for typ in sorted(os.listdir(osp.join(dataset_root, lab))):
                    for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))):
                        seq_info = [lab, typ, vie]
                        seq_path = osp.join(dataset_root, *seq_info)
                        seq_dirs = sorted(os.listdir(seq_path))
                        if seq_dirs:
                            seq_dirs = [osp.join(seq_path, entry)
                                        for entry in seq_dirs]
                            if data_in_use is not None:
                                # zip stops at the shorter input, so entries
                                # beyond len(data_in_use) are dropped.
                                seq_dirs = [entry for entry, use_bl in zip(
                                    seq_dirs, data_in_use) if use_bl]
                            seqs_info_list.append([*seq_info, seq_dirs])
                        else:
                            msg_mgr.log_debug(
                                'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie))
            return seqs_info_list

        self.seqs_info = get_seqs_info_list(
            train_set if training else test_set)