Diff of /opengait/data/dataset.py [000000] .. [fd9ef4]

Switch to unified view

a b/opengait/data/dataset.py
1
import os
2
import pickle
3
import os.path as osp
4
import torch.utils.data as tordata
5
import json
6
from utils import get_msg_mgr
7
8
9
class DataSet(tordata.Dataset):
    """Gait-sequence dataset backed by a label/type/view directory tree.

    seqs_info: list where each element describes one gait sequence as
        [label, type, view, paths] (paths is the list of .pkl files for
        that sequence).
    """

    def __init__(self, data_cfg, training):
        """
        Args:
            data_cfg (dict): dataset config; must provide 'dataset_root',
                'dataset_partition' and 'cache'; may provide 'data_in_use'
                (a boolean mask over the per-sequence modality files).
            training (bool): True selects the TRAIN_SET partition,
                False selects TEST_SET.
        """
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']
        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]

        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))
        # Per-sequence cache slots; filled lazily, only kept when self.cache.
        self.seqs_data = [None] * len(self)
        # label -> indices of every sequence carrying that label.
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)
        if self.cache:
            self.__load_all_data()

    def __len__(self):
        return len(self.seqs_info)

    def __loader__(self, paths):
        """Load every .pkl file in `paths` (sorted) and return the loaded
        objects as a list.

        Raises:
            ValueError: if a path is not a .pkl file, if the loaded items
                do not all have the same length, or if an item is empty.
        """
        paths = sorted(paths)
        data_list = []
        for pth in paths:
            if not pth.endswith('.pkl'):
                raise ValueError('- Loader - just support .pkl !!!')
            # `with` closes the file; the original's extra f.close() was redundant.
            with open(pth, 'rb') as f:
                data_list.append(pickle.load(f))
        for idx, data in enumerate(data_list):
            if len(data) != len(data_list[0]):
                raise ValueError(
                    'Each input data({}) should have the same length.'.format(paths[idx]))
            if len(data) == 0:
                raise ValueError(
                    'Each input data({}) should have at least one element.'.format(paths[idx]))
        return data_list

    def __getitem__(self, idx):
        """Return (data_list, seq_info) for sequence `idx`.

        Loads from disk on every access when caching is disabled; otherwise
        loads once and memoizes the result in self.seqs_data.
        """
        if not self.cache:
            data_list = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            data_list = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_list
        else:
            data_list = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_list, seq_info

    def __load_all_data(self):
        """Eagerly populate self.seqs_data by visiting every index once."""
        for idx in range(len(self)):
            self[idx]

    def __dataset_parser(self, data_config, training):
        """Scan dataset_root and build self.seqs_info for the chosen split.

        Reads the TRAIN_SET/TEST_SET id lists from the partition JSON, drops
        ids missing on disk (logging them), then walks label/type/view
        directories collecting the .pkl paths of each sequence.
        """
        dataset_root = data_config['dataset_root']
        # Optional [n] boolean mask; None means "use every file".
        # (The original used a bare `except:` here, which would also have
        # hidden unrelated errors.)
        data_in_use = data_config.get('data_in_use')

        with open(data_config['dataset_partition'], "rb") as f:
            partition = json.load(f)
        train_set = partition["TRAIN_SET"]
        test_set = partition["TEST_SET"]
        label_list = os.listdir(dataset_root)
        labels_on_disk = set(label_list)  # O(1) membership instead of list scans
        train_set = [label for label in train_set if label in labels_on_disk]
        test_set = [label for label in test_set if label in labels_on_disk]
        partitioned = set(train_set) | set(test_set)
        miss_pids = [label for label in label_list if label not in partitioned]
        # NOTE(review): get_msg_mgr is project-local; presumably returns a
        # logger-like singleton with log_info/log_debug — confirm in utils.
        msg_mgr = get_msg_mgr()

        def log_pid_list(pid_list):
            # Abbreviate long id lists to keep the log readable.
            if len(pid_list) >= 3:
                msg_mgr.log_info('[%s, %s, ..., %s]' %
                                 (pid_list[0], pid_list[1], pid_list[-1]))
            else:
                msg_mgr.log_info(pid_list)

        if len(miss_pids) > 0:
            msg_mgr.log_debug('-------- Miss Pid List --------')
            msg_mgr.log_debug(miss_pids)
        if training:
            msg_mgr.log_info("-------- Train Pid List --------")
            log_pid_list(train_set)
        else:
            msg_mgr.log_info("-------- Test Pid List --------")
            log_pid_list(test_set)

        def get_seqs_info_list(label_set):
            seqs_info_list = []
            for lab in label_set:
                for typ in sorted(os.listdir(osp.join(dataset_root, lab))):
                    for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))):
                        seq_info = [lab, typ, vie]
                        seq_path = osp.join(dataset_root, *seq_info)
                        seq_dirs = sorted(os.listdir(seq_path))
                        if seq_dirs:
                            # `sub` instead of shadowing the builtin `dir`.
                            seq_dirs = [osp.join(seq_path, sub)
                                        for sub in seq_dirs]
                            if data_in_use is not None:
                                # Keep only the modalities flagged True.
                                seq_dirs = [sub for sub, use_bl in zip(
                                    seq_dirs, data_in_use) if use_bl]
                            seqs_info_list.append([*seq_info, seq_dirs])
                        else:
                            msg_mgr.log_debug(
                                'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie))
            return seqs_info_list

        self.seqs_info = get_seqs_info_list(
            train_set) if training else get_seqs_info_list(test_set)