a b/datasets/SUSTech1K/pretreatment_SUSTech1K.py
1
# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
2
import argparse
3
import logging
4
import multiprocessing as mp
5
import os
6
import pickle
7
from collections import defaultdict
8
from functools import partial
9
from pathlib import Path
10
from typing import Tuple
11
12
import cv2
13
import numpy as np
14
from tqdm import tqdm
15
16
import json
17
import open3d as o3d
18
19
def compare_pcd_rgb_timestamp(pcd_file, rgb_file):
    """Extract comparable timestamps (in seconds) from a PCD path and an RGB path.

    The PCD filename stem is already a float timestamp; 0.05 s is added to it
    (presumably a fixed LiDAR/camera capture offset — TODO confirm against the
    sensor rig). The RGB filename stem encodes the timestamp with no decimal
    point: the first 10 digits are whole seconds, the rest the fraction.

    Returns:
        Tuple of (pcd_time, rgb_time) as floats.
    """
    pcd_stem = pcd_file.split('/')[-1].replace('.pcd', '')
    rgb_stem = rgb_file.split('/')[-1].replace('.jpg', '')
    pcd_time = float(pcd_stem) + 0.05
    rgb_time = float(f'{rgb_stem[:10]}.{rgb_stem[10:]}')
    return pcd_time, rgb_time
23
24
25
26
def _load_modality(modality: str, data_files: list, img_size: int):
    """Load one modality's frame files into array form.

    Args:
        modality (str): Modality folder name ('PCDs', 'RGB_raw', 'Sils_raw',
            'Sils_aligned', 'Pose', 'PCDs_depths', 'PCDs_sils').
        data_files (list): Paths of this modality's frame files, in order.
        img_size (int): Target square size for resized images.

    Returns:
        Tuple (data, hws): `data` is the loaded/stacked frame data; `hws` is
        the per-frame (H, W) array for 'RGB_raw' and None for all others.

    Raises:
        ValueError: If `modality` is not a known SUSTech1K modality (the
            original silently fell through and pickled stale data).
    """
    if modality == 'PCDs':
        # Raw LiDAR point clouds: one (N, 3) array per frame; frame sizes
        # differ, so keep a list rather than stacking.
        return [np.asarray(o3d.io.read_point_cloud(f).points) for f in data_files], None
    if modality == 'RGB_raw':
        imgs = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in data_files]
        # Original (H, W) per frame, pickled separately as 'Ratios-HW'.
        hws = np.asarray([img.shape[:2] for img in imgs])
        data = np.asarray([cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
                           for img in imgs])
        return data, hws
    if modality in ('Sils_raw', 'Sils_aligned'):
        sils = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in data_files]
        return np.asarray([cv2.resize(s, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
                           for s in sils]), None
    if modality == 'Pose':
        poses = []
        for f in data_files:
            # Context manager closes the handle (the original leaked it).
            with open(f) as fp:
                poses.append(json.load(fp))
        return np.asarray(poses), None
    if modality == 'PCDs_depths':
        imgs = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in data_files]
        # (H, W, C) -> (C, H, W); depth projections are stored unresized.
        return np.asarray([img.transpose(2, 0, 1) for img in imgs]), None
    if modality == 'PCDs_sils':
        return np.asarray([cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in data_files]), None
    raise ValueError(f'Unknown modality: {modality}')


def _dump_pkl(obj, dst_path: str, file_name: str) -> None:
    """Pickle `obj` to dst_path/file_name, closing the file handle."""
    with open(os.path.join(dst_path, file_name), 'wb') as f:
        pickle.dump(obj, f)


def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a group of multi-modal sequence files and saves them in pickle format.

    Two sets of pickles are written per sequence: every modality as-is, then a
    second, LiDAR/camera-synchronized set (filenames prefixed 'sync') that
    keeps only the frames whose PCD/RGB timestamps differ by at most 20 ms.

    Args:
        img_groups (Tuple): ((sid, seq, view), [(modality, file_paths), ...]).
        output_path (Path): Output root path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Kept for interface compatibility; unused here.
    """
    sinfo = img_groups[0]
    img_paths = img_groups[1]  # list of (modality, file paths)
    dst_path = os.path.join(output_path, *sinfo)
    os.makedirs(dst_path, exist_ok=True)

    threshold = 0.020  # max PCD/RGB timestamp gap considered "synchronized" (20 ms)
    pcd_list, rgb_list = [], []
    cnt = 0  # running prefix so pickle filenames sort in creation order

    # Pass 1: dump every modality unsynchronized; remember the PCD/RGB file
    # lists for the timestamp matching below.
    for modality, data_files in img_paths:
        data, hws = _load_modality(modality, data_files, img_size)
        if modality == 'PCDs':
            pcd_list = data_files
        elif modality == 'RGB_raw':
            rgb_list = data_files
            _dump_pkl(hws, dst_path, f'{cnt:02d}-{sinfo[2]}-Camera-Ratios-HW.pkl')
            cnt += 1
        sensor = 'LiDAR' if 'PCDs' in modality else 'Camera'
        _dump_pkl(data, dst_path, f'{cnt:02d}-{sinfo[2]}-{sensor}-{modality}.pkl')
        cnt += 1

    # Greedy nearest-timestamp matching: for each PCD frame keep the closest
    # RGB frame, accepted only when the gap is within `threshold`.
    pcd_indexs, rgb_indexs = [], []
    for pcd_index, pcd_file in enumerate(pcd_list):
        best_diff = 1.0  # gaps of >= 1 s are never accepted
        best_rgb = 0
        for rgb_index, rgb_file in enumerate(rgb_list):
            pcd_t, rgb_t = compare_pcd_rgb_timestamp(pcd_file, rgb_file)
            diff = abs(pcd_t - rgb_t)
            if diff < best_diff:
                best_diff, best_rgb = diff, rgb_index
        if best_diff <= threshold:
            pcd_indexs.append(pcd_index)
            rgb_indexs.append(best_rgb)

    # Each PCD index is appended at most once, so pcd_indexs can never hold
    # duplicates -- the useful check is whether one RGB frame was matched to
    # several PCDs (the original tested pcd_indexs and printed a tautology).
    if len(set(rgb_indexs)) != len(rgb_indexs):
        print(sinfo, pcd_indexs, rgb_indexs, len(set(rgb_indexs)) == len(rgb_indexs))

    # Pass 2: dump the synchronized subset of every modality. LiDAR-derived
    # modalities follow the matched PCD indices; camera-derived ones follow
    # the matched RGB indices.
    for modality, data_files in img_paths:
        keep = pcd_indexs if 'PCDs' in modality else rgb_indexs
        data_files = [data_files[i] for i in keep]
        data, hws = _load_modality(modality, data_files, img_size)
        if modality == 'RGB_raw':
            _dump_pkl(hws, dst_path, f'{cnt:02d}-sync-{sinfo[2]}-Camera-Ratios-HW.pkl')
            cnt += 1
        sensor = 'LiDAR' if 'PCDs' in modality else 'Camera'
        _dump_pkl(data, dst_path, f'{cnt:02d}-sync-{sinfo[2]}-{sensor}-{modality}.pkl')
        cnt += 1
166
167
168
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Reads a dataset and saves the data in pickle format.

    Expects `input_path` laid out as id/type/view/modality/files and dispatches
    one imgs2pickle job per (id, type, view) folder to a process pool.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name forwarded to imgs2pickle.
    """
    img_groups = defaultdict(list)
    logging.info(f'Listing {input_path}')
    total_modalities = 0  # number of (id, type, view, modality) folders found
    for id_ in tqdm(sorted(os.listdir(input_path))):
        for type_ in os.listdir(os.path.join(input_path, id_)):
            for view_ in os.listdir(os.path.join(input_path, id_, type_)):
                # Sorted so modalities are appended in a deterministic order.
                for modality in sorted(os.listdir(os.path.join(input_path, id_, type_, view_))):
                    modality_path = os.path.join(input_path, id_, type_, view_, modality)
                    file_names = sorted(os.listdir(modality_path))
                    file_names = [os.path.join(modality_path, file_name) for file_name in file_names]
                    img_groups[(id_, type_, view_)].append((modality, file_names))
                    total_modalities += 1

    # NOTE: this counts modality folders, not individual files (the original
    # logged the same number as 'Total files listed').
    logging.info(f'Total modality folders listed: {total_modalities}')

    # `with` guarantees the progress bar is closed even if a worker raises
    # (the original never closed it).
    with tqdm(total=len(img_groups), desc='Pretreating', unit='folder') as progress:
        with mp.Pool(workers) as pool:
            logging.info(f'Start pretreating {input_path}')
            job = partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset)
            for _ in pool.imap_unordered(job, img_groups.items()):
                progress.update(1)
    logging.info('Done')
200
201
202
if __name__ == '__main__':
    # CLI entry point: parse options, configure file logging, run pretreat().
    parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
    parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.')
    parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
    parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
    parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64')
    parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
    args = parser.parse_args()

    # All log output goes to the (truncated) log file, not the console.
    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        # Bump the root logger to DEBUG and record every parsed option.
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for name, value in vars(args).items():
            logging.debug(f'{name}: {value}')

    pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)